diff --git a/llama3/base_ir/8b_fp16_nondecomposed.mlir b/llama3/base_ir/8b_fp16_nondecomposed.mlir
index 348768f..b7839ce 100644
--- a/llama3/base_ir/8b_fp16_nondecomposed.mlir
+++ b/llama3/base_ir/8b_fp16_nondecomposed.mlir
@@ -1,5 +1,3 @@
-#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 module @module {
   util.global private @__auto.token_embd.weight = #stream.parameter.named<"model"::"token_embd.weight"> : tensor<128256x4096xf16>
   util.global private @__auto.blk.0.attn_norm.weight = #stream.parameter.named<"model"::"blk.0.attn_norm.weight"> : tensor<4096xf32>
@@ -293,35885 +291,41149 @@ module @module {
   util.global private @__auto.output_norm.weight = #stream.parameter.named<"model"::"output_norm.weight"> : tensor<4096xf32>
   util.global private @__auto.output.weight = #stream.parameter.named<"model"::"output.weight"> : tensor<128256x4096xf16>
   func.func @prefill_bs4(%arg0: !torch.vtensor<[4,?],si64>, %arg1: !torch.vtensor<[4],si64>, %arg2: !torch.vtensor<[4,?],si64>, %arg3: !torch.tensor<[?,2097152],f16>) -> !torch.vtensor<[4,?,128256],f16> attributes {torch.assume_strict_symbolic_shapes} {
+    %0 = torch.vtensor.literal(dense<0xFC00> : tensor<f16>) : !torch.vtensor<[],f16>
     %__auto.token_embd.weight = util.global.load @__auto.token_embd.weight : tensor<128256x4096xf16>
-    %0 = torch_c.from_builtin_tensor %__auto.token_embd.weight : tensor<128256x4096xf16> -> !torch.vtensor<[128256,4096],f16>
+    %1 = torch_c.from_builtin_tensor %__auto.token_embd.weight : tensor<128256x4096xf16> -> !torch.vtensor<[128256,4096],f16>
     %__auto.blk.0.attn_norm.weight = util.global.load @__auto.blk.0.attn_norm.weight : tensor<4096xf32>
-    %1 = torch_c.from_builtin_tensor %__auto.blk.0.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %2 = torch_c.from_builtin_tensor %__auto.blk.0.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.0.attn_q.weight = util.global.load @__auto.blk.0.attn_q.weight : tensor<4096x4096xf16>
-    %2 = torch_c.from_builtin_tensor %__auto.blk.0.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %3 = torch_c.from_builtin_tensor %__auto.blk.0.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.0.attn_k.weight = util.global.load @__auto.blk.0.attn_k.weight : tensor<1024x4096xf16>
-    %3 = torch_c.from_builtin_tensor %__auto.blk.0.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %4 = torch_c.from_builtin_tensor %__auto.blk.0.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.0.attn_v.weight = util.global.load @__auto.blk.0.attn_v.weight : tensor<1024x4096xf16>
-    %4 = torch_c.from_builtin_tensor %__auto.blk.0.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %5 = torch_c.from_builtin_tensor %__auto.blk.0.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.0.attn_output.weight = util.global.load @__auto.blk.0.attn_output.weight : tensor<4096x4096xf16>
-    %5 = torch_c.from_builtin_tensor %__auto.blk.0.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %6 = torch_c.from_builtin_tensor %__auto.blk.0.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.0.ffn_norm.weight = util.global.load @__auto.blk.0.ffn_norm.weight : tensor<4096xf32>
-    %6 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %7 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.0.ffn_gate.weight = util.global.load @__auto.blk.0.ffn_gate.weight : tensor<14336x4096xf16>
-    %7 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %8 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.0.ffn_up.weight = util.global.load @__auto.blk.0.ffn_up.weight : tensor<14336x4096xf16>
-    %8 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %9 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.0.ffn_down.weight = util.global.load @__auto.blk.0.ffn_down.weight : tensor<4096x14336xf16>
-    %9 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %10 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.1.attn_norm.weight = util.global.load @__auto.blk.1.attn_norm.weight : tensor<4096xf32>
-    %10 = torch_c.from_builtin_tensor %__auto.blk.1.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %11 = torch_c.from_builtin_tensor %__auto.blk.1.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.1.attn_q.weight = util.global.load @__auto.blk.1.attn_q.weight : tensor<4096x4096xf16>
-    %11 = torch_c.from_builtin_tensor %__auto.blk.1.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %12 = torch_c.from_builtin_tensor %__auto.blk.1.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.1.attn_k.weight = util.global.load @__auto.blk.1.attn_k.weight : tensor<1024x4096xf16>
-    %12 = torch_c.from_builtin_tensor %__auto.blk.1.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %13 = torch_c.from_builtin_tensor %__auto.blk.1.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.1.attn_v.weight = util.global.load @__auto.blk.1.attn_v.weight : tensor<1024x4096xf16>
-    %13 = torch_c.from_builtin_tensor %__auto.blk.1.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %14 = torch_c.from_builtin_tensor %__auto.blk.1.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.1.attn_output.weight = util.global.load @__auto.blk.1.attn_output.weight : tensor<4096x4096xf16>
-    %14 = torch_c.from_builtin_tensor %__auto.blk.1.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %15 = torch_c.from_builtin_tensor %__auto.blk.1.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.1.ffn_norm.weight = util.global.load @__auto.blk.1.ffn_norm.weight : tensor<4096xf32>
-    %15 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %16 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.1.ffn_gate.weight = util.global.load @__auto.blk.1.ffn_gate.weight : tensor<14336x4096xf16>
-    %16 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %17 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.1.ffn_up.weight = util.global.load @__auto.blk.1.ffn_up.weight : tensor<14336x4096xf16>
-    %17 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %18 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.1.ffn_down.weight = util.global.load @__auto.blk.1.ffn_down.weight : tensor<4096x14336xf16>
-    %18 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %19 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.2.attn_norm.weight = util.global.load @__auto.blk.2.attn_norm.weight : tensor<4096xf32>
-    %19 = torch_c.from_builtin_tensor %__auto.blk.2.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %20 = torch_c.from_builtin_tensor %__auto.blk.2.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.2.attn_q.weight = util.global.load @__auto.blk.2.attn_q.weight : tensor<4096x4096xf16>
-    %20 = torch_c.from_builtin_tensor %__auto.blk.2.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %21 = torch_c.from_builtin_tensor %__auto.blk.2.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.2.attn_k.weight = util.global.load @__auto.blk.2.attn_k.weight : tensor<1024x4096xf16>
-    %21 = torch_c.from_builtin_tensor %__auto.blk.2.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %22 = torch_c.from_builtin_tensor %__auto.blk.2.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.2.attn_v.weight = util.global.load @__auto.blk.2.attn_v.weight : tensor<1024x4096xf16>
-    %22 = torch_c.from_builtin_tensor %__auto.blk.2.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %23 = torch_c.from_builtin_tensor %__auto.blk.2.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.2.attn_output.weight = util.global.load @__auto.blk.2.attn_output.weight : tensor<4096x4096xf16>
-    %23 = torch_c.from_builtin_tensor %__auto.blk.2.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %24 = torch_c.from_builtin_tensor %__auto.blk.2.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.2.ffn_norm.weight = util.global.load @__auto.blk.2.ffn_norm.weight : tensor<4096xf32>
-    %24 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %25 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.2.ffn_gate.weight = util.global.load @__auto.blk.2.ffn_gate.weight : tensor<14336x4096xf16>
-    %25 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %26 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.2.ffn_up.weight = util.global.load @__auto.blk.2.ffn_up.weight : tensor<14336x4096xf16>
-    %26 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %27 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.2.ffn_down.weight = util.global.load @__auto.blk.2.ffn_down.weight : tensor<4096x14336xf16>
-    %27 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %28 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.3.attn_norm.weight = util.global.load @__auto.blk.3.attn_norm.weight : tensor<4096xf32>
-    %28 = torch_c.from_builtin_tensor %__auto.blk.3.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %29 = torch_c.from_builtin_tensor %__auto.blk.3.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.3.attn_q.weight = util.global.load @__auto.blk.3.attn_q.weight : tensor<4096x4096xf16>
-    %29 = torch_c.from_builtin_tensor %__auto.blk.3.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %30 = torch_c.from_builtin_tensor %__auto.blk.3.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.3.attn_k.weight = util.global.load @__auto.blk.3.attn_k.weight : tensor<1024x4096xf16>
-    %30 = torch_c.from_builtin_tensor %__auto.blk.3.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %31 = torch_c.from_builtin_tensor %__auto.blk.3.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.3.attn_v.weight = util.global.load @__auto.blk.3.attn_v.weight : tensor<1024x4096xf16>
-    %31 = torch_c.from_builtin_tensor %__auto.blk.3.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %32 = torch_c.from_builtin_tensor %__auto.blk.3.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.3.attn_output.weight = util.global.load @__auto.blk.3.attn_output.weight : tensor<4096x4096xf16>
-    %32 = torch_c.from_builtin_tensor %__auto.blk.3.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %33 = torch_c.from_builtin_tensor %__auto.blk.3.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.3.ffn_norm.weight = util.global.load @__auto.blk.3.ffn_norm.weight : tensor<4096xf32>
-    %33 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %34 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.3.ffn_gate.weight = util.global.load @__auto.blk.3.ffn_gate.weight : tensor<14336x4096xf16>
-    %34 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %35 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.3.ffn_up.weight = util.global.load @__auto.blk.3.ffn_up.weight : tensor<14336x4096xf16>
-    %35 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %36 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.3.ffn_down.weight = util.global.load @__auto.blk.3.ffn_down.weight : tensor<4096x14336xf16>
-    %36 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %37 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.4.attn_norm.weight = util.global.load @__auto.blk.4.attn_norm.weight : tensor<4096xf32>
-    %37 = torch_c.from_builtin_tensor %__auto.blk.4.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %38 = torch_c.from_builtin_tensor %__auto.blk.4.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.4.attn_q.weight = util.global.load @__auto.blk.4.attn_q.weight : tensor<4096x4096xf16>
-    %38 = torch_c.from_builtin_tensor %__auto.blk.4.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %39 = torch_c.from_builtin_tensor %__auto.blk.4.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.4.attn_k.weight = util.global.load @__auto.blk.4.attn_k.weight : tensor<1024x4096xf16>
-    %39 = torch_c.from_builtin_tensor %__auto.blk.4.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %40 = torch_c.from_builtin_tensor %__auto.blk.4.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.4.attn_v.weight = util.global.load @__auto.blk.4.attn_v.weight : tensor<1024x4096xf16>
-    %40 = torch_c.from_builtin_tensor %__auto.blk.4.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %41 = torch_c.from_builtin_tensor %__auto.blk.4.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.4.attn_output.weight = util.global.load @__auto.blk.4.attn_output.weight : tensor<4096x4096xf16>
-    %41 = torch_c.from_builtin_tensor %__auto.blk.4.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %42 = torch_c.from_builtin_tensor %__auto.blk.4.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.4.ffn_norm.weight = util.global.load @__auto.blk.4.ffn_norm.weight : tensor<4096xf32>
-    %42 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %43 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.4.ffn_gate.weight = util.global.load @__auto.blk.4.ffn_gate.weight : tensor<14336x4096xf16>
-    %43 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %44 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.4.ffn_up.weight = util.global.load @__auto.blk.4.ffn_up.weight : tensor<14336x4096xf16>
-    %44 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %45 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.4.ffn_down.weight = util.global.load @__auto.blk.4.ffn_down.weight : tensor<4096x14336xf16>
-    %45 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %46 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.5.attn_norm.weight = util.global.load @__auto.blk.5.attn_norm.weight : tensor<4096xf32>
-    %46 = torch_c.from_builtin_tensor %__auto.blk.5.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %47 = torch_c.from_builtin_tensor %__auto.blk.5.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.5.attn_q.weight = util.global.load @__auto.blk.5.attn_q.weight : tensor<4096x4096xf16>
-    %47 = torch_c.from_builtin_tensor %__auto.blk.5.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %48 = torch_c.from_builtin_tensor %__auto.blk.5.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.5.attn_k.weight = util.global.load @__auto.blk.5.attn_k.weight : tensor<1024x4096xf16>
-    %48 = torch_c.from_builtin_tensor %__auto.blk.5.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %49 = torch_c.from_builtin_tensor %__auto.blk.5.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.5.attn_v.weight = util.global.load @__auto.blk.5.attn_v.weight : tensor<1024x4096xf16>
-    %49 = torch_c.from_builtin_tensor %__auto.blk.5.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %50 = torch_c.from_builtin_tensor %__auto.blk.5.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.5.attn_output.weight = util.global.load @__auto.blk.5.attn_output.weight : tensor<4096x4096xf16>
-    %50 = torch_c.from_builtin_tensor %__auto.blk.5.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %51 = torch_c.from_builtin_tensor %__auto.blk.5.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.5.ffn_norm.weight = util.global.load @__auto.blk.5.ffn_norm.weight : tensor<4096xf32>
-    %51 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %52 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.5.ffn_gate.weight = util.global.load @__auto.blk.5.ffn_gate.weight : tensor<14336x4096xf16>
-    %52 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %53 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.5.ffn_up.weight = util.global.load @__auto.blk.5.ffn_up.weight : tensor<14336x4096xf16>
-    %53 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %54 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.5.ffn_down.weight = util.global.load @__auto.blk.5.ffn_down.weight : tensor<4096x14336xf16>
-    %54 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %55 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.6.attn_norm.weight = util.global.load @__auto.blk.6.attn_norm.weight : tensor<4096xf32>
-    %55 = torch_c.from_builtin_tensor %__auto.blk.6.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %56 = torch_c.from_builtin_tensor %__auto.blk.6.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.6.attn_q.weight = util.global.load @__auto.blk.6.attn_q.weight : tensor<4096x4096xf16>
-    %56 = torch_c.from_builtin_tensor %__auto.blk.6.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %57 = torch_c.from_builtin_tensor %__auto.blk.6.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.6.attn_k.weight = util.global.load @__auto.blk.6.attn_k.weight : tensor<1024x4096xf16>
-    %57 = torch_c.from_builtin_tensor %__auto.blk.6.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %58 = torch_c.from_builtin_tensor %__auto.blk.6.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.6.attn_v.weight = util.global.load @__auto.blk.6.attn_v.weight : tensor<1024x4096xf16>
-    %58 = torch_c.from_builtin_tensor %__auto.blk.6.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %59 = torch_c.from_builtin_tensor %__auto.blk.6.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.6.attn_output.weight = util.global.load @__auto.blk.6.attn_output.weight : tensor<4096x4096xf16>
-    %59 = torch_c.from_builtin_tensor %__auto.blk.6.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %60 = torch_c.from_builtin_tensor %__auto.blk.6.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.6.ffn_norm.weight = util.global.load @__auto.blk.6.ffn_norm.weight : tensor<4096xf32>
-    %60 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %61 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.6.ffn_gate.weight = util.global.load @__auto.blk.6.ffn_gate.weight : tensor<14336x4096xf16>
-    %61 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %62 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.6.ffn_up.weight = util.global.load @__auto.blk.6.ffn_up.weight : tensor<14336x4096xf16>
-    %62 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %63 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.6.ffn_down.weight = util.global.load @__auto.blk.6.ffn_down.weight : tensor<4096x14336xf16>
-    %63 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %64 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.7.attn_norm.weight = util.global.load @__auto.blk.7.attn_norm.weight : tensor<4096xf32>
-    %64 = torch_c.from_builtin_tensor %__auto.blk.7.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %65 = torch_c.from_builtin_tensor %__auto.blk.7.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.7.attn_q.weight = util.global.load @__auto.blk.7.attn_q.weight : tensor<4096x4096xf16>
-    %65 = torch_c.from_builtin_tensor %__auto.blk.7.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %66 = torch_c.from_builtin_tensor %__auto.blk.7.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.7.attn_k.weight = util.global.load @__auto.blk.7.attn_k.weight : tensor<1024x4096xf16>
-    %66 = torch_c.from_builtin_tensor %__auto.blk.7.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %67 = torch_c.from_builtin_tensor %__auto.blk.7.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.7.attn_v.weight = util.global.load @__auto.blk.7.attn_v.weight : tensor<1024x4096xf16>
-    %67 = torch_c.from_builtin_tensor %__auto.blk.7.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %68 = torch_c.from_builtin_tensor %__auto.blk.7.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.7.attn_output.weight = util.global.load @__auto.blk.7.attn_output.weight : tensor<4096x4096xf16>
-    %68 = torch_c.from_builtin_tensor %__auto.blk.7.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %69 = torch_c.from_builtin_tensor %__auto.blk.7.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.7.ffn_norm.weight = util.global.load @__auto.blk.7.ffn_norm.weight : tensor<4096xf32>
-    %69 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %70 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.7.ffn_gate.weight = util.global.load @__auto.blk.7.ffn_gate.weight : tensor<14336x4096xf16>
-    %70 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %71 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.7.ffn_up.weight = util.global.load @__auto.blk.7.ffn_up.weight : tensor<14336x4096xf16>
-    %71 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %72 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.7.ffn_down.weight = util.global.load @__auto.blk.7.ffn_down.weight : tensor<4096x14336xf16>
-    %72 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %73 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.8.attn_norm.weight = util.global.load @__auto.blk.8.attn_norm.weight : tensor<4096xf32>
-    %73 = torch_c.from_builtin_tensor %__auto.blk.8.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %74 = torch_c.from_builtin_tensor %__auto.blk.8.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.8.attn_q.weight = util.global.load @__auto.blk.8.attn_q.weight : tensor<4096x4096xf16>
-    %74 = torch_c.from_builtin_tensor %__auto.blk.8.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %75 = torch_c.from_builtin_tensor %__auto.blk.8.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.8.attn_k.weight = util.global.load @__auto.blk.8.attn_k.weight : tensor<1024x4096xf16>
-    %75 = torch_c.from_builtin_tensor %__auto.blk.8.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %76 = torch_c.from_builtin_tensor %__auto.blk.8.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.8.attn_v.weight = util.global.load @__auto.blk.8.attn_v.weight : tensor<1024x4096xf16>
-    %76 = torch_c.from_builtin_tensor %__auto.blk.8.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %77 = torch_c.from_builtin_tensor %__auto.blk.8.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.8.attn_output.weight = util.global.load @__auto.blk.8.attn_output.weight : tensor<4096x4096xf16>
-    %77 = torch_c.from_builtin_tensor %__auto.blk.8.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %78 = torch_c.from_builtin_tensor %__auto.blk.8.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.8.ffn_norm.weight = util.global.load @__auto.blk.8.ffn_norm.weight : tensor<4096xf32>
-    %78 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %79 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.8.ffn_gate.weight = util.global.load @__auto.blk.8.ffn_gate.weight : tensor<14336x4096xf16>
-    %79 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %80 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.8.ffn_up.weight = util.global.load @__auto.blk.8.ffn_up.weight : tensor<14336x4096xf16>
-    %80 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %81 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.8.ffn_down.weight = util.global.load @__auto.blk.8.ffn_down.weight : tensor<4096x14336xf16>
-    %81 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %82 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.9.attn_norm.weight = util.global.load @__auto.blk.9.attn_norm.weight : tensor<4096xf32>
-    %82 = torch_c.from_builtin_tensor %__auto.blk.9.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %83 = torch_c.from_builtin_tensor %__auto.blk.9.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.9.attn_q.weight = util.global.load @__auto.blk.9.attn_q.weight : tensor<4096x4096xf16>
-    %83 = torch_c.from_builtin_tensor %__auto.blk.9.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %84 = torch_c.from_builtin_tensor %__auto.blk.9.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.9.attn_k.weight = util.global.load @__auto.blk.9.attn_k.weight : tensor<1024x4096xf16>
-    %84 = torch_c.from_builtin_tensor %__auto.blk.9.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %85 = torch_c.from_builtin_tensor %__auto.blk.9.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.9.attn_v.weight = util.global.load @__auto.blk.9.attn_v.weight : tensor<1024x4096xf16>
-    %85 = torch_c.from_builtin_tensor %__auto.blk.9.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %86 = torch_c.from_builtin_tensor %__auto.blk.9.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.9.attn_output.weight = util.global.load @__auto.blk.9.attn_output.weight : tensor<4096x4096xf16>
-    %86 = torch_c.from_builtin_tensor %__auto.blk.9.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %87 = torch_c.from_builtin_tensor %__auto.blk.9.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.9.ffn_norm.weight = util.global.load @__auto.blk.9.ffn_norm.weight : tensor<4096xf32>
-    %87 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %88 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.9.ffn_gate.weight = util.global.load @__auto.blk.9.ffn_gate.weight : tensor<14336x4096xf16>
-    %88 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %89 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.9.ffn_up.weight = util.global.load @__auto.blk.9.ffn_up.weight : tensor<14336x4096xf16>
-    %89 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %90 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.9.ffn_down.weight = util.global.load @__auto.blk.9.ffn_down.weight : tensor<4096x14336xf16>
-    %90 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %91 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.10.attn_norm.weight = util.global.load @__auto.blk.10.attn_norm.weight : tensor<4096xf32>
-    %91 = torch_c.from_builtin_tensor %__auto.blk.10.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %92 = torch_c.from_builtin_tensor %__auto.blk.10.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.10.attn_q.weight = util.global.load @__auto.blk.10.attn_q.weight : tensor<4096x4096xf16>
-    %92 = torch_c.from_builtin_tensor %__auto.blk.10.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %93 = torch_c.from_builtin_tensor %__auto.blk.10.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.10.attn_k.weight = util.global.load @__auto.blk.10.attn_k.weight : tensor<1024x4096xf16>
-    %93 = torch_c.from_builtin_tensor %__auto.blk.10.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %94 = torch_c.from_builtin_tensor %__auto.blk.10.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.10.attn_v.weight = util.global.load @__auto.blk.10.attn_v.weight : tensor<1024x4096xf16>
-    %94 = torch_c.from_builtin_tensor %__auto.blk.10.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %95 = torch_c.from_builtin_tensor %__auto.blk.10.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.10.attn_output.weight = util.global.load @__auto.blk.10.attn_output.weight : tensor<4096x4096xf16>
-    %95 = torch_c.from_builtin_tensor %__auto.blk.10.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %96 = torch_c.from_builtin_tensor %__auto.blk.10.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.10.ffn_norm.weight = util.global.load @__auto.blk.10.ffn_norm.weight : tensor<4096xf32>
-    %96 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %97 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.10.ffn_gate.weight = util.global.load @__auto.blk.10.ffn_gate.weight : tensor<14336x4096xf16>
-    %97 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %98 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.10.ffn_up.weight = util.global.load @__auto.blk.10.ffn_up.weight : tensor<14336x4096xf16>
-    %98 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %99 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.10.ffn_down.weight = util.global.load @__auto.blk.10.ffn_down.weight : tensor<4096x14336xf16>
-    %99 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %100 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.11.attn_norm.weight = util.global.load @__auto.blk.11.attn_norm.weight : tensor<4096xf32>
-    %100 = torch_c.from_builtin_tensor %__auto.blk.11.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %101 = torch_c.from_builtin_tensor %__auto.blk.11.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.11.attn_q.weight = util.global.load @__auto.blk.11.attn_q.weight : tensor<4096x4096xf16>
-    %101 = torch_c.from_builtin_tensor %__auto.blk.11.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %102 = torch_c.from_builtin_tensor %__auto.blk.11.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.11.attn_k.weight = util.global.load @__auto.blk.11.attn_k.weight : tensor<1024x4096xf16>
-    %102 = torch_c.from_builtin_tensor %__auto.blk.11.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %103 = torch_c.from_builtin_tensor %__auto.blk.11.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.11.attn_v.weight = util.global.load @__auto.blk.11.attn_v.weight : tensor<1024x4096xf16>
-    %103 = torch_c.from_builtin_tensor %__auto.blk.11.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %104 = torch_c.from_builtin_tensor %__auto.blk.11.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.11.attn_output.weight = util.global.load @__auto.blk.11.attn_output.weight : tensor<4096x4096xf16>
-    %104 = torch_c.from_builtin_tensor %__auto.blk.11.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %105 = torch_c.from_builtin_tensor %__auto.blk.11.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.11.ffn_norm.weight = util.global.load @__auto.blk.11.ffn_norm.weight : tensor<4096xf32>
-    %105 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %106 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.11.ffn_gate.weight = util.global.load @__auto.blk.11.ffn_gate.weight : tensor<14336x4096xf16>
-    %106 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %107 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.11.ffn_up.weight = util.global.load @__auto.blk.11.ffn_up.weight : tensor<14336x4096xf16>
-    %107 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %108 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.11.ffn_down.weight = util.global.load @__auto.blk.11.ffn_down.weight : tensor<4096x14336xf16>
-    %108 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %109 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.12.attn_norm.weight = util.global.load @__auto.blk.12.attn_norm.weight : tensor<4096xf32>
-    %109 = torch_c.from_builtin_tensor %__auto.blk.12.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %110 = torch_c.from_builtin_tensor %__auto.blk.12.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.12.attn_q.weight = util.global.load @__auto.blk.12.attn_q.weight : tensor<4096x4096xf16>
-    %110 = torch_c.from_builtin_tensor %__auto.blk.12.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %111 = torch_c.from_builtin_tensor %__auto.blk.12.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.12.attn_k.weight = util.global.load @__auto.blk.12.attn_k.weight : tensor<1024x4096xf16>
-    %111 = torch_c.from_builtin_tensor %__auto.blk.12.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %112 = torch_c.from_builtin_tensor %__auto.blk.12.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.12.attn_v.weight = util.global.load @__auto.blk.12.attn_v.weight : tensor<1024x4096xf16>
-    %112 = torch_c.from_builtin_tensor %__auto.blk.12.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %113 = torch_c.from_builtin_tensor %__auto.blk.12.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.12.attn_output.weight = util.global.load @__auto.blk.12.attn_output.weight : tensor<4096x4096xf16>
-    %113 = torch_c.from_builtin_tensor %__auto.blk.12.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %114 = torch_c.from_builtin_tensor %__auto.blk.12.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.12.ffn_norm.weight = util.global.load @__auto.blk.12.ffn_norm.weight : tensor<4096xf32>
-    %114 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %115 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.12.ffn_gate.weight = util.global.load @__auto.blk.12.ffn_gate.weight : tensor<14336x4096xf16>
-    %115 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %116 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.12.ffn_up.weight = util.global.load @__auto.blk.12.ffn_up.weight : tensor<14336x4096xf16>
-    %116 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %117 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.12.ffn_down.weight = util.global.load @__auto.blk.12.ffn_down.weight : tensor<4096x14336xf16>
-    %117 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %118 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.13.attn_norm.weight = util.global.load @__auto.blk.13.attn_norm.weight : tensor<4096xf32>
-    %118 = torch_c.from_builtin_tensor %__auto.blk.13.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %119 = torch_c.from_builtin_tensor %__auto.blk.13.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.13.attn_q.weight = util.global.load @__auto.blk.13.attn_q.weight : tensor<4096x4096xf16>
-    %119 = torch_c.from_builtin_tensor %__auto.blk.13.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %120 = torch_c.from_builtin_tensor %__auto.blk.13.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.13.attn_k.weight = util.global.load @__auto.blk.13.attn_k.weight : tensor<1024x4096xf16>
-    %120 = torch_c.from_builtin_tensor %__auto.blk.13.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %121 = torch_c.from_builtin_tensor %__auto.blk.13.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.13.attn_v.weight = util.global.load @__auto.blk.13.attn_v.weight : tensor<1024x4096xf16>
-    %121 = torch_c.from_builtin_tensor %__auto.blk.13.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %122 = torch_c.from_builtin_tensor %__auto.blk.13.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.13.attn_output.weight = util.global.load @__auto.blk.13.attn_output.weight : tensor<4096x4096xf16>
-    %122 = torch_c.from_builtin_tensor %__auto.blk.13.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %123 = torch_c.from_builtin_tensor %__auto.blk.13.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.13.ffn_norm.weight = util.global.load @__auto.blk.13.ffn_norm.weight : tensor<4096xf32>
-    %123 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %124 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.13.ffn_gate.weight = util.global.load @__auto.blk.13.ffn_gate.weight : tensor<14336x4096xf16>
-    %124 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %125 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.13.ffn_up.weight = util.global.load @__auto.blk.13.ffn_up.weight : tensor<14336x4096xf16>
-    %125 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %126 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.13.ffn_down.weight = util.global.load @__auto.blk.13.ffn_down.weight : tensor<4096x14336xf16>
-    %126 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %127 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.14.attn_norm.weight = util.global.load @__auto.blk.14.attn_norm.weight : tensor<4096xf32>
-    %127 = torch_c.from_builtin_tensor %__auto.blk.14.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %128 = torch_c.from_builtin_tensor %__auto.blk.14.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.14.attn_q.weight = util.global.load @__auto.blk.14.attn_q.weight : tensor<4096x4096xf16>
-    %128 = torch_c.from_builtin_tensor %__auto.blk.14.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %129 = torch_c.from_builtin_tensor %__auto.blk.14.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.14.attn_k.weight = util.global.load @__auto.blk.14.attn_k.weight : tensor<1024x4096xf16>
-    %129 = torch_c.from_builtin_tensor %__auto.blk.14.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %130 = torch_c.from_builtin_tensor %__auto.blk.14.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.14.attn_v.weight = util.global.load @__auto.blk.14.attn_v.weight : tensor<1024x4096xf16>
-    %130 = torch_c.from_builtin_tensor %__auto.blk.14.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %131 = torch_c.from_builtin_tensor %__auto.blk.14.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.14.attn_output.weight = util.global.load @__auto.blk.14.attn_output.weight : tensor<4096x4096xf16>
-    %131 = torch_c.from_builtin_tensor %__auto.blk.14.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %132 = torch_c.from_builtin_tensor %__auto.blk.14.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.14.ffn_norm.weight = util.global.load @__auto.blk.14.ffn_norm.weight : tensor<4096xf32>
-    %132 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %133 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.14.ffn_gate.weight = util.global.load @__auto.blk.14.ffn_gate.weight : tensor<14336x4096xf16>
-    %133 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %134 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.14.ffn_up.weight = util.global.load @__auto.blk.14.ffn_up.weight : tensor<14336x4096xf16>
-    %134 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %135 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.14.ffn_down.weight = util.global.load @__auto.blk.14.ffn_down.weight : tensor<4096x14336xf16>
-    %135 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %136 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.15.attn_norm.weight = util.global.load @__auto.blk.15.attn_norm.weight : tensor<4096xf32>
-    %136 = torch_c.from_builtin_tensor %__auto.blk.15.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %137 = torch_c.from_builtin_tensor %__auto.blk.15.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.15.attn_q.weight = util.global.load @__auto.blk.15.attn_q.weight : tensor<4096x4096xf16>
-    %137 = torch_c.from_builtin_tensor %__auto.blk.15.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %138 = torch_c.from_builtin_tensor %__auto.blk.15.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.15.attn_k.weight = util.global.load @__auto.blk.15.attn_k.weight : tensor<1024x4096xf16>
-    %138 = torch_c.from_builtin_tensor %__auto.blk.15.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %139 = torch_c.from_builtin_tensor %__auto.blk.15.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.15.attn_v.weight = util.global.load @__auto.blk.15.attn_v.weight : tensor<1024x4096xf16>
-    %139 = torch_c.from_builtin_tensor %__auto.blk.15.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %140 = torch_c.from_builtin_tensor %__auto.blk.15.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.15.attn_output.weight = util.global.load @__auto.blk.15.attn_output.weight : tensor<4096x4096xf16>
-    %140 = torch_c.from_builtin_tensor %__auto.blk.15.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %141 = torch_c.from_builtin_tensor %__auto.blk.15.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.15.ffn_norm.weight = util.global.load @__auto.blk.15.ffn_norm.weight : tensor<4096xf32>
-    %141 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %142 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.15.ffn_gate.weight = util.global.load @__auto.blk.15.ffn_gate.weight : tensor<14336x4096xf16>
-    %142 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %143 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.15.ffn_up.weight = util.global.load @__auto.blk.15.ffn_up.weight : tensor<14336x4096xf16>
-    %143 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %144 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.15.ffn_down.weight = util.global.load @__auto.blk.15.ffn_down.weight : tensor<4096x14336xf16>
-    %144 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %145 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.16.attn_norm.weight = util.global.load @__auto.blk.16.attn_norm.weight : tensor<4096xf32>
-    %145 = torch_c.from_builtin_tensor %__auto.blk.16.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %146 = torch_c.from_builtin_tensor %__auto.blk.16.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.16.attn_q.weight = util.global.load @__auto.blk.16.attn_q.weight : tensor<4096x4096xf16>
-    %146 = torch_c.from_builtin_tensor %__auto.blk.16.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %147 = torch_c.from_builtin_tensor %__auto.blk.16.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.16.attn_k.weight = util.global.load @__auto.blk.16.attn_k.weight : tensor<1024x4096xf16>
-    %147 = torch_c.from_builtin_tensor %__auto.blk.16.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %148 = torch_c.from_builtin_tensor %__auto.blk.16.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.16.attn_v.weight = util.global.load @__auto.blk.16.attn_v.weight : tensor<1024x4096xf16>
-    %148 = torch_c.from_builtin_tensor %__auto.blk.16.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %149 = torch_c.from_builtin_tensor %__auto.blk.16.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.16.attn_output.weight = util.global.load @__auto.blk.16.attn_output.weight : tensor<4096x4096xf16>
-    %149 = torch_c.from_builtin_tensor %__auto.blk.16.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %150 = torch_c.from_builtin_tensor %__auto.blk.16.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.16.ffn_norm.weight = util.global.load @__auto.blk.16.ffn_norm.weight : tensor<4096xf32>
-    %150 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %151 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.16.ffn_gate.weight = util.global.load @__auto.blk.16.ffn_gate.weight : tensor<14336x4096xf16>
-    %151 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %152 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.16.ffn_up.weight = util.global.load @__auto.blk.16.ffn_up.weight : tensor<14336x4096xf16>
-    %152 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %153 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.16.ffn_down.weight = util.global.load @__auto.blk.16.ffn_down.weight : tensor<4096x14336xf16>
-    %153 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %154 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.17.attn_norm.weight = util.global.load @__auto.blk.17.attn_norm.weight : tensor<4096xf32>
-    %154 = torch_c.from_builtin_tensor %__auto.blk.17.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %155 = torch_c.from_builtin_tensor %__auto.blk.17.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.17.attn_q.weight = util.global.load @__auto.blk.17.attn_q.weight : tensor<4096x4096xf16>
-    %155 = torch_c.from_builtin_tensor %__auto.blk.17.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %156 = torch_c.from_builtin_tensor %__auto.blk.17.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.17.attn_k.weight = util.global.load @__auto.blk.17.attn_k.weight : tensor<1024x4096xf16>
-    %156 = torch_c.from_builtin_tensor %__auto.blk.17.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %157 = torch_c.from_builtin_tensor %__auto.blk.17.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.17.attn_v.weight = util.global.load @__auto.blk.17.attn_v.weight : tensor<1024x4096xf16>
-    %157 = torch_c.from_builtin_tensor %__auto.blk.17.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %158 = torch_c.from_builtin_tensor %__auto.blk.17.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.17.attn_output.weight = util.global.load @__auto.blk.17.attn_output.weight : tensor<4096x4096xf16>
-    %158 = torch_c.from_builtin_tensor %__auto.blk.17.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %159 = torch_c.from_builtin_tensor %__auto.blk.17.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.17.ffn_norm.weight = util.global.load @__auto.blk.17.ffn_norm.weight : tensor<4096xf32>
-    %159 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %160 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.17.ffn_gate.weight = util.global.load @__auto.blk.17.ffn_gate.weight : tensor<14336x4096xf16>
-    %160 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %161 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.17.ffn_up.weight = util.global.load @__auto.blk.17.ffn_up.weight : tensor<14336x4096xf16>
-    %161 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %162 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.17.ffn_down.weight = util.global.load @__auto.blk.17.ffn_down.weight : tensor<4096x14336xf16>
-    %162 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %163 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.18.attn_norm.weight = util.global.load @__auto.blk.18.attn_norm.weight : tensor<4096xf32>
-    %163 = torch_c.from_builtin_tensor %__auto.blk.18.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %164 = torch_c.from_builtin_tensor %__auto.blk.18.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.18.attn_q.weight = util.global.load @__auto.blk.18.attn_q.weight : tensor<4096x4096xf16>
-    %164 = torch_c.from_builtin_tensor %__auto.blk.18.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %165 = torch_c.from_builtin_tensor %__auto.blk.18.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.18.attn_k.weight = util.global.load @__auto.blk.18.attn_k.weight : tensor<1024x4096xf16>
-    %165 = torch_c.from_builtin_tensor %__auto.blk.18.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %166 = torch_c.from_builtin_tensor %__auto.blk.18.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.18.attn_v.weight = util.global.load @__auto.blk.18.attn_v.weight : tensor<1024x4096xf16>
-    %166 = torch_c.from_builtin_tensor %__auto.blk.18.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %167 = torch_c.from_builtin_tensor %__auto.blk.18.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.18.attn_output.weight = util.global.load @__auto.blk.18.attn_output.weight : tensor<4096x4096xf16>
-    %167 = torch_c.from_builtin_tensor %__auto.blk.18.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %168 = torch_c.from_builtin_tensor %__auto.blk.18.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.18.ffn_norm.weight = util.global.load @__auto.blk.18.ffn_norm.weight : tensor<4096xf32>
-    %168 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %169 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.18.ffn_gate.weight = util.global.load @__auto.blk.18.ffn_gate.weight : tensor<14336x4096xf16>
-    %169 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %170 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.18.ffn_up.weight = util.global.load @__auto.blk.18.ffn_up.weight : tensor<14336x4096xf16>
-    %170 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %171 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.18.ffn_down.weight = util.global.load @__auto.blk.18.ffn_down.weight : tensor<4096x14336xf16>
-    %171 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %172 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.19.attn_norm.weight = util.global.load @__auto.blk.19.attn_norm.weight : tensor<4096xf32>
-    %172 = torch_c.from_builtin_tensor %__auto.blk.19.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %173 = torch_c.from_builtin_tensor %__auto.blk.19.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.19.attn_q.weight = util.global.load @__auto.blk.19.attn_q.weight : tensor<4096x4096xf16>
-    %173 = torch_c.from_builtin_tensor %__auto.blk.19.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %174 = torch_c.from_builtin_tensor %__auto.blk.19.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.19.attn_k.weight = util.global.load @__auto.blk.19.attn_k.weight : tensor<1024x4096xf16>
-    %174 = torch_c.from_builtin_tensor %__auto.blk.19.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %175 = torch_c.from_builtin_tensor %__auto.blk.19.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.19.attn_v.weight = util.global.load @__auto.blk.19.attn_v.weight : tensor<1024x4096xf16>
-    %175 = torch_c.from_builtin_tensor %__auto.blk.19.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %176 = torch_c.from_builtin_tensor %__auto.blk.19.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.19.attn_output.weight = util.global.load @__auto.blk.19.attn_output.weight : tensor<4096x4096xf16>
-    %176 = torch_c.from_builtin_tensor %__auto.blk.19.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %177 = torch_c.from_builtin_tensor %__auto.blk.19.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.19.ffn_norm.weight = util.global.load @__auto.blk.19.ffn_norm.weight : tensor<4096xf32>
-    %177 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %178 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.19.ffn_gate.weight = util.global.load @__auto.blk.19.ffn_gate.weight : tensor<14336x4096xf16>
-    %178 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %179 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.19.ffn_up.weight = util.global.load @__auto.blk.19.ffn_up.weight : tensor<14336x4096xf16>
-    %179 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %180 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.19.ffn_down.weight = util.global.load @__auto.blk.19.ffn_down.weight : tensor<4096x14336xf16>
-    %180 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+    %181 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
     %__auto.blk.20.attn_norm.weight = util.global.load @__auto.blk.20.attn_norm.weight : tensor<4096xf32>
-    %181 = torch_c.from_builtin_tensor %__auto.blk.20.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %182 = torch_c.from_builtin_tensor %__auto.blk.20.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.20.attn_q.weight = util.global.load @__auto.blk.20.attn_q.weight : tensor<4096x4096xf16>
-    %182 = torch_c.from_builtin_tensor %__auto.blk.20.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %183 = torch_c.from_builtin_tensor %__auto.blk.20.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.20.attn_k.weight = util.global.load @__auto.blk.20.attn_k.weight : tensor<1024x4096xf16>
-    %183 = torch_c.from_builtin_tensor %__auto.blk.20.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %184 = torch_c.from_builtin_tensor %__auto.blk.20.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.20.attn_v.weight = util.global.load @__auto.blk.20.attn_v.weight : tensor<1024x4096xf16>
-    %184 = torch_c.from_builtin_tensor %__auto.blk.20.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+    %185 = torch_c.from_builtin_tensor %__auto.blk.20.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
     %__auto.blk.20.attn_output.weight = util.global.load @__auto.blk.20.attn_output.weight : tensor<4096x4096xf16>
-    %185 = torch_c.from_builtin_tensor %__auto.blk.20.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+    %186 = torch_c.from_builtin_tensor %__auto.blk.20.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
     %__auto.blk.20.ffn_norm.weight = util.global.load @__auto.blk.20.ffn_norm.weight : tensor<4096xf32>
-    %186 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+    %187 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
     %__auto.blk.20.ffn_gate.weight = util.global.load @__auto.blk.20.ffn_gate.weight : tensor<14336x4096xf16>
-    %187 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+    %188 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
     %__auto.blk.20.ffn_up.weight = util.global.load @__auto.blk.20.ffn_up.weight : tensor<14336x4096xf16>
-    %188 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_up.weight : tensor<14336x4096xf16> ->
!torch.vtensor<[14336,4096],f16> + %189 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.20.ffn_down.weight = util.global.load @__auto.blk.20.ffn_down.weight : tensor<4096x14336xf16> - %189 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %190 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.21.attn_norm.weight = util.global.load @__auto.blk.21.attn_norm.weight : tensor<4096xf32> - %190 = torch_c.from_builtin_tensor %__auto.blk.21.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %191 = torch_c.from_builtin_tensor %__auto.blk.21.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.21.attn_q.weight = util.global.load @__auto.blk.21.attn_q.weight : tensor<4096x4096xf16> - %191 = torch_c.from_builtin_tensor %__auto.blk.21.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %192 = torch_c.from_builtin_tensor %__auto.blk.21.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.21.attn_k.weight = util.global.load @__auto.blk.21.attn_k.weight : tensor<1024x4096xf16> - %192 = torch_c.from_builtin_tensor %__auto.blk.21.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %193 = torch_c.from_builtin_tensor %__auto.blk.21.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.21.attn_v.weight = util.global.load @__auto.blk.21.attn_v.weight : tensor<1024x4096xf16> - %193 = torch_c.from_builtin_tensor %__auto.blk.21.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %194 = torch_c.from_builtin_tensor %__auto.blk.21.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.21.attn_output.weight = util.global.load @__auto.blk.21.attn_output.weight : tensor<4096x4096xf16> - %194 = torch_c.from_builtin_tensor %__auto.blk.21.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %195 = torch_c.from_builtin_tensor %__auto.blk.21.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.21.ffn_norm.weight = util.global.load @__auto.blk.21.ffn_norm.weight : tensor<4096xf32> - %195 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %196 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.21.ffn_gate.weight = util.global.load @__auto.blk.21.ffn_gate.weight : tensor<14336x4096xf16> - %196 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %197 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.21.ffn_up.weight = util.global.load @__auto.blk.21.ffn_up.weight : tensor<14336x4096xf16> - %197 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %198 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.21.ffn_down.weight = util.global.load @__auto.blk.21.ffn_down.weight : tensor<4096x14336xf16> - %198 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_down.weight : tensor<4096x14336xf16> -> 
!torch.vtensor<[4096,14336],f16> + %199 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.22.attn_norm.weight = util.global.load @__auto.blk.22.attn_norm.weight : tensor<4096xf32> - %199 = torch_c.from_builtin_tensor %__auto.blk.22.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %200 = torch_c.from_builtin_tensor %__auto.blk.22.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.22.attn_q.weight = util.global.load @__auto.blk.22.attn_q.weight : tensor<4096x4096xf16> - %200 = torch_c.from_builtin_tensor %__auto.blk.22.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %201 = torch_c.from_builtin_tensor %__auto.blk.22.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.22.attn_k.weight = util.global.load @__auto.blk.22.attn_k.weight : tensor<1024x4096xf16> - %201 = torch_c.from_builtin_tensor %__auto.blk.22.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %202 = torch_c.from_builtin_tensor %__auto.blk.22.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.22.attn_v.weight = util.global.load @__auto.blk.22.attn_v.weight : tensor<1024x4096xf16> - %202 = torch_c.from_builtin_tensor %__auto.blk.22.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %203 = torch_c.from_builtin_tensor %__auto.blk.22.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.22.attn_output.weight = util.global.load @__auto.blk.22.attn_output.weight : tensor<4096x4096xf16> - %203 = torch_c.from_builtin_tensor %__auto.blk.22.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %204 = torch_c.from_builtin_tensor %__auto.blk.22.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.22.ffn_norm.weight = util.global.load @__auto.blk.22.ffn_norm.weight : tensor<4096xf32> - %204 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %205 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.22.ffn_gate.weight = util.global.load @__auto.blk.22.ffn_gate.weight : tensor<14336x4096xf16> - %205 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %206 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.22.ffn_up.weight = util.global.load @__auto.blk.22.ffn_up.weight : tensor<14336x4096xf16> - %206 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %207 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.22.ffn_down.weight = util.global.load @__auto.blk.22.ffn_down.weight : tensor<4096x14336xf16> - %207 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %208 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.23.attn_norm.weight = util.global.load @__auto.blk.23.attn_norm.weight : tensor<4096xf32> - %208 = torch_c.from_builtin_tensor %__auto.blk.23.attn_norm.weight : tensor<4096xf32> -> 
+ %209 = torch_c.from_builtin_tensor %__auto.blk.23.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.23.attn_q.weight = util.global.load @__auto.blk.23.attn_q.weight : tensor<4096x4096xf16>
- %209 = torch_c.from_builtin_tensor %__auto.blk.23.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %210 = torch_c.from_builtin_tensor %__auto.blk.23.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.23.attn_k.weight = util.global.load @__auto.blk.23.attn_k.weight : tensor<1024x4096xf16>
- %210 = torch_c.from_builtin_tensor %__auto.blk.23.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %211 = torch_c.from_builtin_tensor %__auto.blk.23.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.23.attn_v.weight = util.global.load @__auto.blk.23.attn_v.weight : tensor<1024x4096xf16>
- %211 = torch_c.from_builtin_tensor %__auto.blk.23.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %212 = torch_c.from_builtin_tensor %__auto.blk.23.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.23.attn_output.weight = util.global.load @__auto.blk.23.attn_output.weight : tensor<4096x4096xf16>
- %212 = torch_c.from_builtin_tensor %__auto.blk.23.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %213 = torch_c.from_builtin_tensor %__auto.blk.23.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.23.ffn_norm.weight = util.global.load @__auto.blk.23.ffn_norm.weight : tensor<4096xf32>
- %213 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %214 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.23.ffn_gate.weight = util.global.load @__auto.blk.23.ffn_gate.weight : tensor<14336x4096xf16>
- %214 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %215 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.23.ffn_up.weight = util.global.load @__auto.blk.23.ffn_up.weight : tensor<14336x4096xf16>
- %215 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %216 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.23.ffn_down.weight = util.global.load @__auto.blk.23.ffn_down.weight : tensor<4096x14336xf16>
- %216 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+ %217 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
 %__auto.blk.24.attn_norm.weight = util.global.load @__auto.blk.24.attn_norm.weight : tensor<4096xf32>
- %217 = torch_c.from_builtin_tensor %__auto.blk.24.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %218 = torch_c.from_builtin_tensor %__auto.blk.24.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.24.attn_q.weight = util.global.load @__auto.blk.24.attn_q.weight : tensor<4096x4096xf16>
- %218 = torch_c.from_builtin_tensor %__auto.blk.24.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %219 = torch_c.from_builtin_tensor %__auto.blk.24.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.24.attn_k.weight = util.global.load @__auto.blk.24.attn_k.weight : tensor<1024x4096xf16>
- %219 = torch_c.from_builtin_tensor %__auto.blk.24.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %220 = torch_c.from_builtin_tensor %__auto.blk.24.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.24.attn_v.weight = util.global.load @__auto.blk.24.attn_v.weight : tensor<1024x4096xf16>
- %220 = torch_c.from_builtin_tensor %__auto.blk.24.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %221 = torch_c.from_builtin_tensor %__auto.blk.24.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.24.attn_output.weight = util.global.load @__auto.blk.24.attn_output.weight : tensor<4096x4096xf16>
- %221 = torch_c.from_builtin_tensor %__auto.blk.24.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %222 = torch_c.from_builtin_tensor %__auto.blk.24.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.24.ffn_norm.weight = util.global.load @__auto.blk.24.ffn_norm.weight : tensor<4096xf32>
- %222 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %223 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.24.ffn_gate.weight = util.global.load @__auto.blk.24.ffn_gate.weight : tensor<14336x4096xf16>
- %223 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %224 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.24.ffn_up.weight = util.global.load @__auto.blk.24.ffn_up.weight : tensor<14336x4096xf16>
- %224 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %225 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.24.ffn_down.weight = util.global.load @__auto.blk.24.ffn_down.weight : tensor<4096x14336xf16>
- %225 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+ %226 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
 %__auto.blk.25.attn_norm.weight = util.global.load @__auto.blk.25.attn_norm.weight : tensor<4096xf32>
- %226 = torch_c.from_builtin_tensor %__auto.blk.25.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %227 = torch_c.from_builtin_tensor %__auto.blk.25.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.25.attn_q.weight = util.global.load @__auto.blk.25.attn_q.weight : tensor<4096x4096xf16>
- %227 = torch_c.from_builtin_tensor %__auto.blk.25.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %228 = torch_c.from_builtin_tensor %__auto.blk.25.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.25.attn_k.weight = util.global.load @__auto.blk.25.attn_k.weight : tensor<1024x4096xf16>
- %228 = torch_c.from_builtin_tensor %__auto.blk.25.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %229 = torch_c.from_builtin_tensor %__auto.blk.25.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.25.attn_v.weight = util.global.load @__auto.blk.25.attn_v.weight : tensor<1024x4096xf16>
- %229 = torch_c.from_builtin_tensor %__auto.blk.25.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %230 = torch_c.from_builtin_tensor %__auto.blk.25.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.25.attn_output.weight = util.global.load @__auto.blk.25.attn_output.weight : tensor<4096x4096xf16>
- %230 = torch_c.from_builtin_tensor %__auto.blk.25.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %231 = torch_c.from_builtin_tensor %__auto.blk.25.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.25.ffn_norm.weight = util.global.load @__auto.blk.25.ffn_norm.weight : tensor<4096xf32>
- %231 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %232 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.25.ffn_gate.weight = util.global.load @__auto.blk.25.ffn_gate.weight : tensor<14336x4096xf16>
- %232 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %233 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.25.ffn_up.weight = util.global.load @__auto.blk.25.ffn_up.weight : tensor<14336x4096xf16>
- %233 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %234 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.25.ffn_down.weight = util.global.load @__auto.blk.25.ffn_down.weight : tensor<4096x14336xf16>
- %234 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+ %235 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
 %__auto.blk.26.attn_norm.weight = util.global.load @__auto.blk.26.attn_norm.weight : tensor<4096xf32>
- %235 = torch_c.from_builtin_tensor %__auto.blk.26.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %236 = torch_c.from_builtin_tensor %__auto.blk.26.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.26.attn_q.weight = util.global.load @__auto.blk.26.attn_q.weight : tensor<4096x4096xf16>
- %236 = torch_c.from_builtin_tensor %__auto.blk.26.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %237 = torch_c.from_builtin_tensor %__auto.blk.26.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.26.attn_k.weight = util.global.load @__auto.blk.26.attn_k.weight : tensor<1024x4096xf16>
- %237 = torch_c.from_builtin_tensor %__auto.blk.26.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %238 = torch_c.from_builtin_tensor %__auto.blk.26.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.26.attn_v.weight = util.global.load @__auto.blk.26.attn_v.weight : tensor<1024x4096xf16>
- %238 = torch_c.from_builtin_tensor %__auto.blk.26.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %239 = torch_c.from_builtin_tensor %__auto.blk.26.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.26.attn_output.weight = util.global.load @__auto.blk.26.attn_output.weight : tensor<4096x4096xf16>
- %239 = torch_c.from_builtin_tensor %__auto.blk.26.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %240 = torch_c.from_builtin_tensor %__auto.blk.26.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.26.ffn_norm.weight = util.global.load @__auto.blk.26.ffn_norm.weight : tensor<4096xf32>
- %240 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %241 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.26.ffn_gate.weight = util.global.load @__auto.blk.26.ffn_gate.weight : tensor<14336x4096xf16>
- %241 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %242 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.26.ffn_up.weight = util.global.load @__auto.blk.26.ffn_up.weight : tensor<14336x4096xf16>
- %242 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %243 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.26.ffn_down.weight = util.global.load @__auto.blk.26.ffn_down.weight : tensor<4096x14336xf16>
- %243 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+ %244 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
 %__auto.blk.27.attn_norm.weight = util.global.load @__auto.blk.27.attn_norm.weight : tensor<4096xf32>
- %244 = torch_c.from_builtin_tensor %__auto.blk.27.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %245 = torch_c.from_builtin_tensor %__auto.blk.27.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.27.attn_q.weight = util.global.load @__auto.blk.27.attn_q.weight : tensor<4096x4096xf16>
- %245 = torch_c.from_builtin_tensor %__auto.blk.27.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %246 = torch_c.from_builtin_tensor %__auto.blk.27.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.27.attn_k.weight = util.global.load @__auto.blk.27.attn_k.weight : tensor<1024x4096xf16>
- %246 = torch_c.from_builtin_tensor %__auto.blk.27.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %247 = torch_c.from_builtin_tensor %__auto.blk.27.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.27.attn_v.weight = util.global.load @__auto.blk.27.attn_v.weight : tensor<1024x4096xf16>
- %247 = torch_c.from_builtin_tensor %__auto.blk.27.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %248 = torch_c.from_builtin_tensor %__auto.blk.27.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.27.attn_output.weight = util.global.load @__auto.blk.27.attn_output.weight : tensor<4096x4096xf16>
- %248 = torch_c.from_builtin_tensor %__auto.blk.27.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %249 = torch_c.from_builtin_tensor %__auto.blk.27.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.27.ffn_norm.weight = util.global.load @__auto.blk.27.ffn_norm.weight : tensor<4096xf32>
- %249 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %250 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.27.ffn_gate.weight = util.global.load @__auto.blk.27.ffn_gate.weight : tensor<14336x4096xf16>
- %250 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %251 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.27.ffn_up.weight = util.global.load @__auto.blk.27.ffn_up.weight : tensor<14336x4096xf16>
- %251 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %252 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.27.ffn_down.weight = util.global.load @__auto.blk.27.ffn_down.weight : tensor<4096x14336xf16>
- %252 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+ %253 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
 %__auto.blk.28.attn_norm.weight = util.global.load @__auto.blk.28.attn_norm.weight : tensor<4096xf32>
- %253 = torch_c.from_builtin_tensor %__auto.blk.28.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %254 = torch_c.from_builtin_tensor %__auto.blk.28.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.28.attn_q.weight = util.global.load @__auto.blk.28.attn_q.weight : tensor<4096x4096xf16>
- %254 = torch_c.from_builtin_tensor %__auto.blk.28.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %255 = torch_c.from_builtin_tensor %__auto.blk.28.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.28.attn_k.weight = util.global.load @__auto.blk.28.attn_k.weight : tensor<1024x4096xf16>
- %255 = torch_c.from_builtin_tensor %__auto.blk.28.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %256 = torch_c.from_builtin_tensor %__auto.blk.28.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.28.attn_v.weight = util.global.load @__auto.blk.28.attn_v.weight : tensor<1024x4096xf16>
- %256 = torch_c.from_builtin_tensor %__auto.blk.28.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %257 = torch_c.from_builtin_tensor %__auto.blk.28.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.28.attn_output.weight = util.global.load @__auto.blk.28.attn_output.weight : tensor<4096x4096xf16>
- %257 = torch_c.from_builtin_tensor %__auto.blk.28.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %258 = torch_c.from_builtin_tensor %__auto.blk.28.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.28.ffn_norm.weight = util.global.load @__auto.blk.28.ffn_norm.weight : tensor<4096xf32>
- %258 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %259 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.28.ffn_gate.weight = util.global.load @__auto.blk.28.ffn_gate.weight : tensor<14336x4096xf16>
- %259 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %260 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.28.ffn_up.weight = util.global.load @__auto.blk.28.ffn_up.weight : tensor<14336x4096xf16>
- %260 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %261 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.28.ffn_down.weight = util.global.load @__auto.blk.28.ffn_down.weight : tensor<4096x14336xf16>
- %261 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+ %262 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
 %__auto.blk.29.attn_norm.weight = util.global.load @__auto.blk.29.attn_norm.weight : tensor<4096xf32>
- %262 = torch_c.from_builtin_tensor %__auto.blk.29.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %263 = torch_c.from_builtin_tensor %__auto.blk.29.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.29.attn_q.weight = util.global.load @__auto.blk.29.attn_q.weight : tensor<4096x4096xf16>
- %263 = torch_c.from_builtin_tensor %__auto.blk.29.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %264 = torch_c.from_builtin_tensor %__auto.blk.29.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.29.attn_k.weight = util.global.load @__auto.blk.29.attn_k.weight : tensor<1024x4096xf16>
- %264 = torch_c.from_builtin_tensor %__auto.blk.29.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %265 = torch_c.from_builtin_tensor %__auto.blk.29.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.29.attn_v.weight = util.global.load @__auto.blk.29.attn_v.weight : tensor<1024x4096xf16>
- %265 = torch_c.from_builtin_tensor %__auto.blk.29.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %266 = torch_c.from_builtin_tensor %__auto.blk.29.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.29.attn_output.weight = util.global.load @__auto.blk.29.attn_output.weight : tensor<4096x4096xf16>
- %266 = torch_c.from_builtin_tensor %__auto.blk.29.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %267 = torch_c.from_builtin_tensor %__auto.blk.29.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.29.ffn_norm.weight = util.global.load @__auto.blk.29.ffn_norm.weight : tensor<4096xf32>
- %267 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %268 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.29.ffn_gate.weight = util.global.load @__auto.blk.29.ffn_gate.weight : tensor<14336x4096xf16>
- %268 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %269 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.29.ffn_up.weight = util.global.load @__auto.blk.29.ffn_up.weight : tensor<14336x4096xf16>
- %269 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %270 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.29.ffn_down.weight = util.global.load @__auto.blk.29.ffn_down.weight : tensor<4096x14336xf16>
- %270 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+ %271 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
 %__auto.blk.30.attn_norm.weight = util.global.load @__auto.blk.30.attn_norm.weight : tensor<4096xf32>
- %271 = torch_c.from_builtin_tensor %__auto.blk.30.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %272 = torch_c.from_builtin_tensor %__auto.blk.30.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.30.attn_q.weight = util.global.load @__auto.blk.30.attn_q.weight : tensor<4096x4096xf16>
- %272 = torch_c.from_builtin_tensor %__auto.blk.30.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %273 = torch_c.from_builtin_tensor %__auto.blk.30.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.30.attn_k.weight = util.global.load @__auto.blk.30.attn_k.weight : tensor<1024x4096xf16>
- %273 = torch_c.from_builtin_tensor %__auto.blk.30.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %274 = torch_c.from_builtin_tensor %__auto.blk.30.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.30.attn_v.weight = util.global.load @__auto.blk.30.attn_v.weight : tensor<1024x4096xf16>
- %274 = torch_c.from_builtin_tensor %__auto.blk.30.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %275 = torch_c.from_builtin_tensor %__auto.blk.30.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.30.attn_output.weight = util.global.load @__auto.blk.30.attn_output.weight : tensor<4096x4096xf16>
- %275 = torch_c.from_builtin_tensor %__auto.blk.30.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %276 = torch_c.from_builtin_tensor %__auto.blk.30.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.30.ffn_norm.weight = util.global.load @__auto.blk.30.ffn_norm.weight : tensor<4096xf32>
- %276 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %277 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.30.ffn_gate.weight = util.global.load @__auto.blk.30.ffn_gate.weight : tensor<14336x4096xf16>
- %277 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %278 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.30.ffn_up.weight = util.global.load @__auto.blk.30.ffn_up.weight : tensor<14336x4096xf16>
- %278 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %279 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.30.ffn_down.weight = util.global.load @__auto.blk.30.ffn_down.weight : tensor<4096x14336xf16>
- %279 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+ %280 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
 %__auto.blk.31.attn_norm.weight = util.global.load @__auto.blk.31.attn_norm.weight : tensor<4096xf32>
- %280 = torch_c.from_builtin_tensor %__auto.blk.31.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %281 = torch_c.from_builtin_tensor %__auto.blk.31.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.31.attn_q.weight = util.global.load @__auto.blk.31.attn_q.weight : tensor<4096x4096xf16>
- %281 = torch_c.from_builtin_tensor %__auto.blk.31.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %282 = torch_c.from_builtin_tensor %__auto.blk.31.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.31.attn_k.weight = util.global.load @__auto.blk.31.attn_k.weight : tensor<1024x4096xf16>
- %282 = torch_c.from_builtin_tensor %__auto.blk.31.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %283 = torch_c.from_builtin_tensor %__auto.blk.31.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.31.attn_v.weight = util.global.load @__auto.blk.31.attn_v.weight : tensor<1024x4096xf16>
- %283 = torch_c.from_builtin_tensor %__auto.blk.31.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
+ %284 = torch_c.from_builtin_tensor %__auto.blk.31.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16>
 %__auto.blk.31.attn_output.weight = util.global.load @__auto.blk.31.attn_output.weight : tensor<4096x4096xf16>
- %284 = torch_c.from_builtin_tensor %__auto.blk.31.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
+ %285 = torch_c.from_builtin_tensor %__auto.blk.31.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16>
 %__auto.blk.31.ffn_norm.weight = util.global.load @__auto.blk.31.ffn_norm.weight : tensor<4096xf32>
- %285 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %286 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.blk.31.ffn_gate.weight = util.global.load @__auto.blk.31.ffn_gate.weight : tensor<14336x4096xf16>
- %286 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %287 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.31.ffn_up.weight = util.global.load @__auto.blk.31.ffn_up.weight : tensor<14336x4096xf16>
- %287 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
+ %288 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16>
 %__auto.blk.31.ffn_down.weight = util.global.load @__auto.blk.31.ffn_down.weight : tensor<4096x14336xf16>
- %288 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
+ %289 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16>
 %__auto.output_norm.weight = util.global.load @__auto.output_norm.weight : tensor<4096xf32>
- %289 = torch_c.from_builtin_tensor %__auto.output_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
+ %290 = torch_c.from_builtin_tensor %__auto.output_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32>
 %__auto.output.weight = util.global.load @__auto.output.weight : tensor<128256x4096xf16>
- %290 = torch_c.from_builtin_tensor %__auto.output.weight : tensor<128256x4096xf16> -> !torch.vtensor<[128256,4096],f16>
- %291 = torch.copy.to_vtensor %arg3 : !torch.vtensor<[?,2097152],f16>
- %292 = torch.symbolic_int "s1" {min_val = 2, max_val = 4095} : !torch.int
- %293 = torch.symbolic_int "s2" {min_val = 2, max_val = 9223372036854775806} : !torch.int
- torch.bind_symbolic_shape %arg0, [%292], affine_map<()[s0] -> (4, s0 * 32)> : !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %arg2, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %291, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
- %int-1 = torch.constant.int -1
+ %291 = torch_c.from_builtin_tensor %__auto.output.weight : tensor<128256x4096xf16> -> !torch.vtensor<[128256,4096],f16>
+ %292 = torch.copy.to_vtensor %arg3 : !torch.vtensor<[?,2097152],f16>
+ %293 = torch.symbolic_int "32*s1" {min_val = 64, max_val = 131040} : !torch.int
+ %294 = torch.symbolic_int "s1" {min_val = 2, max_val = 4095} : !torch.int
+ %295 = torch.symbolic_int "s2" {min_val = 0, max_val = 9223372036854775807} : !torch.int
+ torch.bind_symbolic_shape %arg0, [%294], affine_map<()[s0] -> (4, s0 * 32)> : !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %arg2, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %292, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %int1 = torch.constant.int 1
+ %296 = torch.aten.size.int %arg2, %int1 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.int
+ %int0 = torch.constant.int 0
+ %297 = torch.aten.size.int %292, %int0 : !torch.vtensor<[?,2097152],f16>, !torch.int -> !torch.int
+ %int1_0 = torch.constant.int 1
+ %298 = torch.aten.size.int %arg0, %int1_0 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.int
+ %int0_1 = torch.constant.int 0
+ %int1_2 = torch.constant.int 1
+ %none = torch.constant.none
+ %none_3 = torch.constant.none
+ %cpu = torch.constant.device "cpu"
 %false = torch.constant.bool false
- %false_0 = torch.constant.bool false
- %294 = torch.aten.embedding %0, %arg0, %int-1, %false, %false_0 : !torch.vtensor<[128256,4096],f16>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %294, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int6 = torch.constant.int 6
- %295 = torch.prims.convert_element_type %294, %int6 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %295, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %299 = torch.aten.arange.start_step %int0_1, %298, %int1_2, %none, %none_3, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[?],si64>
+ torch.bind_symbolic_shape %299, [%294], affine_map<()[s0] -> (s0 * 32)> : !torch.vtensor<[?],si64>
+ %int-1 = torch.constant.int -1
+ %300 = torch.aten.unsqueeze %arg1, %int-1 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
+ %301 = torch.aten.ge.Tensor %299, %300 : !torch.vtensor<[?],si64>, !torch.vtensor<[4,1],si64> -> !torch.vtensor<[4,?],i1>
+ torch.bind_symbolic_shape %301, [%294], affine_map<()[s0] -> (4, s0 * 32)> : !torch.vtensor<[4,?],i1>
+ %int1_4 = torch.constant.int 1
+ %int1_5 = torch.constant.int 1
+ %302 = torch.prim.ListConstruct %int1_4, %int1_5 : (!torch.int, !torch.int) -> !torch.list<int>
+ %int11 = torch.constant.int 11
+ %none_6 = torch.constant.none
+ %cpu_7 = torch.constant.device "cpu"
+ %false_8 = torch.constant.bool false
+ %303 = torch.aten.ones %302, %int11, %none_6, %cpu_7, %false_8 : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1,1],i1>
+ %int131072 = torch.constant.int 131072
+ %int131072_9 = torch.constant.int 131072
+ %304 = torch.prim.ListConstruct %int131072, %int131072_9 : (!torch.int, !torch.int) -> !torch.list<int>
+ %false_10 = torch.constant.bool false
+ %305 = torch.aten.expand %303, %304, %false_10 : !torch.vtensor<[1,1],i1>, !torch.list<int>, !torch.bool -> !torch.vtensor<[131072,131072],i1>
+ %int1_11 = torch.constant.int 1
+ %306 = torch.aten.triu %305, %int1_11 : !torch.vtensor<[131072,131072],i1>, !torch.int -> !torch.vtensor<[131072,131072],i1>
+ %int0_12 = torch.constant.int 0
+ %307 = torch.aten.unsqueeze %306, %int0_12 : !torch.vtensor<[131072,131072],i1>, !torch.int -> !torch.vtensor<[1,131072,131072],i1>
+ %int1_13 = torch.constant.int 1
+ %308 = torch.aten.unsqueeze %307, %int1_13 : !torch.vtensor<[1,131072,131072],i1>, !torch.int -> !torch.vtensor<[1,1,131072,131072],i1>
 %int2 = torch.constant.int 2
- %296 = torch.aten.pow.Tensor_Scalar %295, %int2 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %296, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int-1_1 = torch.constant.int -1
- %297 = torch.prim.ListConstruct %int-1_1 : (!torch.int) -> !torch.list<int>
+ %int0_14 = torch.constant.int 0
+ %int9223372036854775807 = torch.constant.int 9223372036854775807
+ %int1_15 = torch.constant.int 1
+ %309 = torch.aten.slice.Tensor %308, %int2, %int0_14, %int9223372036854775807, %int1_15 : !torch.vtensor<[1,1,131072,131072],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,131072,131072],i1>
+ %int3 = torch.constant.int 3
+ %int0_16 = torch.constant.int 0
+ %int9223372036854775807_17 = torch.constant.int 9223372036854775807
+ %int1_18 = torch.constant.int 1
+ %310 = torch.aten.slice.Tensor %309, %int3, %int0_16, %int9223372036854775807_17, %int1_18 : !torch.vtensor<[1,1,131072,131072],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,131072,131072],i1>
+ %int0_19 = torch.constant.int 0
+ %int0_20 = torch.constant.int 0
+ %int9223372036854775807_21 = torch.constant.int 9223372036854775807
+ %int1_22 = torch.constant.int 1
+ %311 = torch.aten.slice.Tensor %310, %int0_19, %int0_20, %int9223372036854775807_21, %int1_22 : !torch.vtensor<[1,1,131072,131072],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,131072,131072],i1>
+ %int1_23 = torch.constant.int 1
+ %int0_24 = torch.constant.int 0
+ %int9223372036854775807_25 = torch.constant.int 9223372036854775807
+ %int1_26 = torch.constant.int 1
+ %312 = torch.aten.slice.Tensor %311, %int1_23, %int0_24, %int9223372036854775807_25, %int1_26 : !torch.vtensor<[1,1,131072,131072],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,131072,131072],i1>
+ %int2_27 = torch.constant.int 2
+ %int0_28 = torch.constant.int 0
+ %int1_29 = torch.constant.int 1
+ %313 = torch.aten.slice.Tensor %312, %int2_27, %int0_28, %298, %int1_29 : !torch.vtensor<[1,1,131072,131072],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,?,131072],i1>
+ torch.bind_symbolic_shape %313, [%294], affine_map<()[s0] -> (1, 1, s0 * 32, 131072)> : !torch.vtensor<[1,1,?,131072],i1>
+ %int3_30 = torch.constant.int 3
+ %int0_31 = torch.constant.int 0
+ %int1_32 = torch.constant.int 1
+ %314 = torch.aten.slice.Tensor %313, %int3_30, %int0_31, %298, %int1_32 : !torch.vtensor<[1,1,?,131072],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,?,?],i1>
+ torch.bind_symbolic_shape %314, [%294], affine_map<()[s0] -> (1, 1, s0 * 32, s0 * 32)> : !torch.vtensor<[1,1,?,?],i1>
+ %int0_33 = torch.constant.int 0
+ %int0_34 = torch.constant.int 0
+ %int9223372036854775807_35 = torch.constant.int 9223372036854775807
+ %int1_36 = torch.constant.int 1
+ %315 = torch.aten.slice.Tensor %301, %int0_33, %int0_34, %int9223372036854775807_35, %int1_36 : !torch.vtensor<[4,?],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?],i1>
+ torch.bind_symbolic_shape %315, [%294], affine_map<()[s0] -> (4, s0 * 32)> : !torch.vtensor<[4,?],i1>
+ %int1_37 = torch.constant.int 1
+ %316 = torch.aten.unsqueeze %315, %int1_37 : !torch.vtensor<[4,?],i1>, !torch.int -> !torch.vtensor<[4,1,?],i1>
+ torch.bind_symbolic_shape %316, [%294], affine_map<()[s0] -> (4, 1, s0 * 32)> : !torch.vtensor<[4,1,?],i1>
+ %int2_38 = torch.constant.int 2
+ %317 = torch.aten.unsqueeze %316, %int2_38 : !torch.vtensor<[4,1,?],i1>, !torch.int -> !torch.vtensor<[4,1,1,?],i1>
+ torch.bind_symbolic_shape %317, [%294], affine_map<()[s0] -> (4, 1, 1, s0 * 32)> : !torch.vtensor<[4,1,1,?],i1>
+ %int3_39 = torch.constant.int 3
+ %int0_40 = torch.constant.int 0
+ %int9223372036854775807_41 = torch.constant.int 9223372036854775807
+ %int1_42 = torch.constant.int 1
+ %318 = torch.aten.slice.Tensor %317, %int3_39, %int0_40, %int9223372036854775807_41, %int1_42 : !torch.vtensor<[4,1,1,?],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1,1,?],i1>
+ torch.bind_symbolic_shape %318, [%294], affine_map<()[s0] -> (4, 1, 1, s0 * 32)> : !torch.vtensor<[4,1,1,?],i1>
+ %319 = torch.aten.logical_or %314, %318 : !torch.vtensor<[1,1,?,?],i1>, !torch.vtensor<[4,1,1,?],i1> -> !torch.vtensor<[4,1,?,?],i1>
+ torch.bind_symbolic_shape %319, [%294], affine_map<()[s0] -> (4, 1, s0 * 32, s0 * 32)> : !torch.vtensor<[4,1,?,?],i1>
+ %none_43 = torch.constant.none
+ %320 = torch.aten.clone %0, %none_43 : !torch.vtensor<[],f16>, !torch.none -> !torch.vtensor<[],f16>
+ %321 = torch.aten.detach %320 : !torch.vtensor<[],f16> -> !torch.vtensor<[],f16>
+ %322 = torch.aten.detach %321 : !torch.vtensor<[],f16> -> !torch.vtensor<[],f16>
+ %323 = torch.aten.detach %322 : !torch.vtensor<[],f16> -> !torch.vtensor<[],f16>
+ %int0_44 = torch.constant.int 0
+ %int5 = torch.constant.int 5
+ %int0_45 = torch.constant.int 0
+ %cpu_46 = torch.constant.device "cpu"
+ %none_47 = torch.constant.none
+ %324 = torch.aten.scalar_tensor %int0_44, %int5, %int0_45, %cpu_46, %none_47 : !torch.int, !torch.int, !torch.int, !torch.Device, !torch.none -> !torch.vtensor<[],f16>
+ %325 = torch.aten.where.self %319, %323, %324 : !torch.vtensor<[4,1,?,?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[],f16> -> !torch.vtensor<[4,1,?,?],f16>
+ torch.bind_symbolic_shape %325, [%294], affine_map<()[s0] -> (4, 1, s0 * 32, s0 * 32)> : !torch.vtensor<[4,1,?,?],f16>
+ %int5_48 = torch.constant.int 5
+ %326 = torch.prims.convert_element_type %325, %int5_48 : !torch.vtensor<[4,1,?,?],f16>, !torch.int -> !torch.vtensor<[4,1,?,?],f16>
+ torch.bind_symbolic_shape %326, [%294], affine_map<()[s0] -> (4, 1, s0 * 32, s0 * 32)> : !torch.vtensor<[4,1,?,?],f16>
+ %int5_49 = torch.constant.int 5
+ %327 = torch.prims.convert_element_type %326, %int5_49 : !torch.vtensor<[4,1,?,?],f16>, !torch.int -> !torch.vtensor<[4,1,?,?],f16>
+ torch.bind_symbolic_shape %327, [%294], affine_map<()[s0] -> (4, 1, s0 * 32, s0 * 32)> : !torch.vtensor<[4,1,?,?],f16>
+ %int5_50 = torch.constant.int 5
+ %328 = torch.prims.convert_element_type %1, %int5_50 : !torch.vtensor<[128256,4096],f16>, !torch.int -> !torch.vtensor<[128256,4096],f16>
+ %int-1_51 = torch.constant.int -1
+ %false_52 = torch.constant.bool false
+ %false_53 = torch.constant.bool false
+ %329 = torch.aten.embedding %328, %arg0, %int-1_51, %false_52, %false_53 : !torch.vtensor<[128256,4096],f16>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %329, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int6 = torch.constant.int 6
+ %330 = torch.prims.convert_element_type %329, %int6 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %330, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int2_54 = torch.constant.int 2
+ %331 = torch.aten.pow.Tensor_Scalar %330, %int2_54 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %331, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int-1_55 = torch.constant.int -1
+ %332 = torch.prim.ListConstruct %int-1_55 : (!torch.int) -> !torch.list<int>
 %true = torch.constant.bool true
- %none = torch.constant.none
- %298 = torch.aten.mean.dim %296, %297, %true, %none : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %298, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %none_56 = torch.constant.none
+ %333 = torch.aten.mean.dim %331, %332, %true, %none_56 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %333, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
 %float9.999990e-06 = torch.constant.float 9.9999997473787516E-6
- %int1 = torch.constant.int 1
- %299 = torch.aten.add.Scalar %298, %float9.999990e-06, %int1 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %299, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %300 = torch.aten.rsqrt %299 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %300, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %301 = torch.aten.mul.Tensor %295, %300 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %301, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5 = torch.constant.int 5
- %302 = torch.prims.convert_element_type %301, %int5 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %302, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %303 = torch.aten.mul.Tensor %1, %302 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %303, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_2 = torch.constant.int 5
- %304 = torch.prims.convert_element_type %303, %int5_2 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %304, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int1_57 = torch.constant.int 1
+ %334 = torch.aten.add.Scalar %333, %float9.999990e-06, %int1_57 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %334, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %335 = torch.aten.rsqrt %334 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %335, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %336 = torch.aten.mul.Tensor %330, %335 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %336, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_58 = torch.constant.int 5
+ %337 = torch.prims.convert_element_type %336, %int5_58 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %337, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %338 = torch.aten.mul.Tensor %2, %337 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %338, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_59 = torch.constant.int 5
+ %339 = torch.prims.convert_element_type %338, %int5_59 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %339, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
 %int-2 = torch.constant.int -2
- %int-1_3 = torch.constant.int -1
- %305 = torch.aten.transpose.int %2, %int-2, %int-1_3 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int1_4 = torch.constant.int 1
- %306 = torch.aten.size.int %arg0, %int1_4 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.int
+ %int-1_60 = torch.constant.int -1
+ %340 = torch.aten.transpose.int %3, %int-2, %int-1_60 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int5_61 = torch.constant.int 5
+ %341 = torch.prims.convert_element_type %340, %int5_61 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
 %int4 = torch.constant.int 4
- %307 = torch.aten.mul.int %int4, %306 : !torch.int, !torch.int -> !torch.int
+ %342 = torch.aten.mul.int %int4, %298 : !torch.int, !torch.int -> !torch.int
 %int4096 = torch.constant.int 4096
- %308 = torch.prim.ListConstruct %307, %int4096 : (!torch.int, !torch.int) -> !torch.list<int>
- %309 = torch.aten.view %304, %308 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %309, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %310 = torch.aten.mm %309, %305 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %310, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_5 = torch.constant.int 4
- %int4096_6 = torch.constant.int 4096
- %311 = torch.prim.ListConstruct %int4_5, %306, %int4096_6 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %312 = torch.aten.view %310, %311 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %312, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_7 = torch.constant.int -2
- %int-1_8 = torch.constant.int -1
- %313 = torch.aten.transpose.int %3, %int-2_7, %int-1_8 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_9 = torch.constant.int 4
- %314 = torch.aten.mul.int %int4_9, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_10 = torch.constant.int 4096
- %315 = torch.prim.ListConstruct %314, %int4096_10 : (!torch.int, !torch.int) -> !torch.list<int>
- %316 = torch.aten.view %304, %315 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %316, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %317 = torch.aten.mm %316, %313 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
- torch.bind_symbolic_shape %317, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
- %int4_11 = torch.constant.int 4
+ %343 = torch.prim.ListConstruct %342, %int4096 : (!torch.int, !torch.int) -> !torch.list<int>
+ %344 = torch.aten.view %339, %343 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %344, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %345 = torch.aten.mm %344, %341 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %345, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %int4_62 = torch.constant.int 4
+ %int4096_63 = torch.constant.int 4096
+ %346 = torch.prim.ListConstruct %int4_62, %298, %int4096_63 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %347 = torch.aten.view %345, %346 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %347, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_64 = torch.constant.int -2
+ %int-1_65 = torch.constant.int -1
+ %348 = torch.aten.transpose.int %4, %int-2_64, %int-1_65 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int5_66 = torch.constant.int 5
+ %349 = torch.prims.convert_element_type %348, %int5_66 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int4096_67 = torch.constant.int 4096
+ %350 = torch.prim.ListConstruct %342, %int4096_67 : (!torch.int, !torch.int) -> !torch.list<int>
+ %351 = torch.aten.view %339, %350 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %351, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %352 = torch.aten.mm %351, %349 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
+ torch.bind_symbolic_shape %352, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
+ %int4_68 = torch.constant.int 4
 %int1024 = torch.constant.int 1024
- %318 = torch.prim.ListConstruct %int4_11, %306, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %319 = torch.aten.view %317, %318 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
- torch.bind_symbolic_shape %319, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
- %int-2_12 = torch.constant.int -2
- %int-1_13 = torch.constant.int -1
- %320 = torch.aten.transpose.int %4, %int-2_12, %int-1_13 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_14 = torch.constant.int 4
- %321 = torch.aten.mul.int %int4_14, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_15 = torch.constant.int 4096
- %322 = torch.prim.ListConstruct %321, %int4096_15 : (!torch.int, !torch.int) -> !torch.list<int>
- %323 = torch.aten.view %304, %322 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %323, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %324 = torch.aten.mm %323, %320 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
- torch.bind_symbolic_shape %324, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
- %int4_16 = torch.constant.int 4
- %int1024_17 = torch.constant.int 1024
- %325 = torch.prim.ListConstruct %int4_16, %306, %int1024_17 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %326 = torch.aten.view %324, %325 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
- torch.bind_symbolic_shape %326, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
- %int4_18 = torch.constant.int 4
+ %353 = torch.prim.ListConstruct %int4_68, %298, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %354 = torch.aten.view %352, %353 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
+ torch.bind_symbolic_shape %354, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
+ %int-2_69 = torch.constant.int -2
+ %int-1_70 = torch.constant.int -1
+ %355 = torch.aten.transpose.int %5, %int-2_69, %int-1_70 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int5_71 = torch.constant.int 5
+ %356 = torch.prims.convert_element_type %355, %int5_71 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int4096_72 = torch.constant.int 4096
+ %357 = torch.prim.ListConstruct %342, %int4096_72 : (!torch.int, !torch.int) -> !torch.list<int>
+ %358 = torch.aten.view %339, %357 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %358, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %359 = torch.aten.mm %358, %356 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
+ torch.bind_symbolic_shape %359, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
+ %int4_73 = torch.constant.int 4
+ %int1024_74 = torch.constant.int 1024
+ %360 = torch.prim.ListConstruct %int4_73, %298, %int1024_74 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %361 = torch.aten.view %359, %360 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
+ torch.bind_symbolic_shape %361, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> :
!torch.vtensor<[4,?,1024],f16> + %int4_75 = torch.constant.int 4 %int32 = torch.constant.int 32 %int128 = torch.constant.int 128 - %327 = torch.prim.ListConstruct %int4_18, %306, %int32, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %328 = torch.aten.view %312, %327 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %328, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_19 = torch.constant.int 4 + %362 = torch.prim.ListConstruct %int4_75, %298, %int32, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %363 = torch.aten.view %347, %362 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %363, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_76 = torch.constant.int 4 %int8 = torch.constant.int 8 - %int128_20 = torch.constant.int 128 - %329 = torch.prim.ListConstruct %int4_19, %306, %int8, %int128_20 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %330 = torch.aten.view %319, %329 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %330, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_21 = torch.constant.int 4 - %int8_22 = torch.constant.int 8 - %int128_23 = torch.constant.int 128 - %331 = torch.prim.ListConstruct %int4_21, %306, %int8_22, %int128_23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %332 = torch.aten.view %326, %331 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %332, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072 = torch.constant.int 131072 - %none_24 = torch.constant.none - %none_25 = torch.constant.none - %cpu = torch.constant.device "cpu" - %false_26 = torch.constant.bool false - %333 = torch.aten.arange %int131072, %none_24, %none_25, %cpu, %false_26 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0 = torch.constant.int 0 - %int128_27 = torch.constant.int 128 - %none_28 = torch.constant.none - %none_29 = torch.constant.none - %cpu_30 = torch.constant.device "cpu" - %false_31 = torch.constant.bool false - %334 = torch.aten.arange.start %int0, %int128_27, %none_28, %none_29, %cpu_30, %false_31 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_32 = torch.constant.int 2 - %335 = torch.aten.floor_divide.Scalar %334, %int2_32 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_33 = torch.constant.int 6 - %336 = torch.prims.convert_element_type %335, %int6_33 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_34 = torch.constant.int 128 - %337 = torch.aten.div.Scalar %336, %int128_34 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00 = torch.constant.float 2.000000e+00 - %338 = torch.aten.mul.Scalar %337, %float2.000000e00 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %int128_77 = torch.constant.int 128 + %364 = torch.prim.ListConstruct %int4_76, %298, %int8, %int128_77 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %365 = torch.aten.view %354, %364 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> 
!torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %365, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_78 = torch.constant.int 4 + %int8_79 = torch.constant.int 8 + %int128_80 = torch.constant.int 128 + %366 = torch.prim.ListConstruct %int4_78, %298, %int8_79, %int128_80 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %367 = torch.aten.view %361, %366 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %367, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_81 = torch.constant.int 131072 + %none_82 = torch.constant.none + %none_83 = torch.constant.none + %cpu_84 = torch.constant.device "cpu" + %false_85 = torch.constant.bool false + %368 = torch.aten.arange %int131072_81, %none_82, %none_83, %cpu_84, %false_85 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_86 = torch.constant.int 0 + %int128_87 = torch.constant.int 128 + %int2_88 = torch.constant.int 2 + %int4_89 = torch.constant.int 4 + %none_90 = torch.constant.none + %cpu_91 = torch.constant.device "cpu" + %false_92 = torch.constant.bool false + %369 = torch.aten.arange.start_step %int0_86, %int128_87, %int2_88, %int4_89, %none_90, %cpu_91, %false_92 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_93 = torch.constant.int 6 + %370 = torch.prims.convert_element_type %369, %int6_93 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_94 = torch.constant.int 128 + %371 = torch.aten.div.Scalar %370, %int128_94 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %float5.000000e05 = torch.constant.float 5.000000e+05 - %339 = torch.aten.pow.Scalar %float5.000000e05, %338 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %340 = torch.aten.reciprocal %339 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> + %372 = torch.aten.pow.Scalar %float5.000000e05, %371 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %373 = torch.aten.reciprocal %372 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> %float1.000000e00 = torch.constant.float 1.000000e+00 - %341 = torch.aten.mul.Scalar %340, %float1.000000e00 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_35 = torch.constant.int 1 - %342 = torch.aten.unsqueeze %333, %int1_35 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_36 = torch.constant.int 0 - %343 = torch.aten.unsqueeze %341, %int0_36 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %344 = torch.aten.mul.Tensor %342, %343 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_37 = torch.constant.int 1 - %345 = torch.aten.size.int %312, %int1_37 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_38 = torch.constant.int 0 - %346 = torch.aten.add.int %int0_38, %345 : !torch.int, !torch.int -> !torch.int - %int0_39 = torch.constant.int 0 - %int0_40 = torch.constant.int 0 - %int1_41 = torch.constant.int 1 - %347 = torch.aten.slice.Tensor %344, %int0_39, %int0_40, %346, %int1_41 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %347, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : 
!torch.vtensor<[?,128],f32> - %int1_42 = torch.constant.int 1 - %int0_43 = torch.constant.int 0 - %int9223372036854775807 = torch.constant.int 9223372036854775807 - %int1_44 = torch.constant.int 1 - %348 = torch.aten.slice.Tensor %347, %int1_42, %int0_43, %int9223372036854775807, %int1_44 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %348, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_45 = torch.constant.int 1 - %int0_46 = torch.constant.int 0 - %int9223372036854775807_47 = torch.constant.int 9223372036854775807 - %int1_48 = torch.constant.int 1 - %349 = torch.aten.slice.Tensor %348, %int1_45, %int0_46, %int9223372036854775807_47, %int1_48 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %349, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_49 = torch.constant.int 0 - %350 = torch.aten.unsqueeze %349, %int0_49 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %350, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_50 = torch.constant.int 1 - %int0_51 = torch.constant.int 0 - %int9223372036854775807_52 = torch.constant.int 9223372036854775807 - %int1_53 = torch.constant.int 1 - %351 = torch.aten.slice.Tensor %350, %int1_50, %int0_51, %int9223372036854775807_52, %int1_53 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %351, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_54 = torch.constant.int 2 - %int0_55 = torch.constant.int 0 - %int9223372036854775807_56 = torch.constant.int 9223372036854775807 - %int1_57 = torch.constant.int 1 - %352 = torch.aten.slice.Tensor %351, %int2_54, %int0_55, %int9223372036854775807_56, %int1_57 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %352, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_58 = torch.constant.int 4 - %int1_59 = torch.constant.int 1 - %int1_60 = torch.constant.int 1 - %353 = torch.prim.ListConstruct %int4_58, %int1_59, %int1_60 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %354 = torch.aten.repeat %352, %353 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %354, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_61 = torch.constant.int 6 - %355 = torch.prims.convert_element_type %328, %int6_61 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %355, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %356 = torch_c.to_builtin_tensor %355 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %357 = torch_c.to_builtin_tensor %354 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %358 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%356, %357) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %359 = torch_c.from_builtin_tensor %358 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %359, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : 
!torch.vtensor<[4,?,32,128],f32> - %int5_62 = torch.constant.int 5 - %360 = torch.prims.convert_element_type %359, %int5_62 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %360, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_63 = torch.constant.int 131072 - %none_64 = torch.constant.none - %none_65 = torch.constant.none - %cpu_66 = torch.constant.device "cpu" - %false_67 = torch.constant.bool false - %361 = torch.aten.arange %int131072_63, %none_64, %none_65, %cpu_66, %false_67 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_68 = torch.constant.int 0 - %int128_69 = torch.constant.int 128 - %none_70 = torch.constant.none - %none_71 = torch.constant.none - %cpu_72 = torch.constant.device "cpu" - %false_73 = torch.constant.bool false - %362 = torch.aten.arange.start %int0_68, %int128_69, %none_70, %none_71, %cpu_72, %false_73 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_74 = torch.constant.int 2 - %363 = torch.aten.floor_divide.Scalar %362, %int2_74 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_75 = torch.constant.int 6 - %364 = torch.prims.convert_element_type %363, %int6_75 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_76 = torch.constant.int 128 - %365 = torch.aten.div.Scalar %364, %int128_76 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_77 = torch.constant.float 2.000000e+00 - %366 = torch.aten.mul.Scalar %365, %float2.000000e00_77 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_78 = torch.constant.float 5.000000e+05 - %367 = torch.aten.pow.Scalar %float5.000000e05_78, %366 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %368 = torch.aten.reciprocal %367 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_79 = torch.constant.float 1.000000e+00 - %369 = torch.aten.mul.Scalar %368, %float1.000000e00_79 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_80 = torch.constant.int 1 - %370 = torch.aten.unsqueeze %361, %int1_80 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_81 = torch.constant.int 0 - %371 = torch.aten.unsqueeze %369, %int0_81 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %372 = torch.aten.mul.Tensor %370, %371 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_82 = torch.constant.int 1 - %373 = torch.aten.size.int %319, %int1_82 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_83 = torch.constant.int 0 - %374 = torch.aten.add.int %int0_83, %373 : !torch.int, !torch.int -> !torch.int - %int0_84 = torch.constant.int 0 - %int0_85 = torch.constant.int 0 - %int1_86 = torch.constant.int 1 - %375 = torch.aten.slice.Tensor %372, %int0_84, %int0_85, %374, %int1_86 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %375, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_87 = torch.constant.int 1 - %int0_88 = torch.constant.int 0 - %int9223372036854775807_89 = torch.constant.int 9223372036854775807 - %int1_90 = torch.constant.int 1 - %376 = 
torch.aten.slice.Tensor %375, %int1_87, %int0_88, %int9223372036854775807_89, %int1_90 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %376, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_91 = torch.constant.int 1 - %int0_92 = torch.constant.int 0 - %int9223372036854775807_93 = torch.constant.int 9223372036854775807 - %int1_94 = torch.constant.int 1 - %377 = torch.aten.slice.Tensor %376, %int1_91, %int0_92, %int9223372036854775807_93, %int1_94 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %377, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_95 = torch.constant.int 0 - %378 = torch.aten.unsqueeze %377, %int0_95 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %378, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> + %374 = torch.aten.mul.Scalar %373, %float1.000000e00 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %375 = torch.aten.reciprocal %374 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00 = torch.constant.float 6.2831853071795862 + %376 = torch.aten.mul.Scalar %375, %float6.283190e00 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03 = torch.constant.float 8.192000e+03 + %377 = torch.aten.gt.Scalar %376, %float8.192000e03 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_95 = torch.constant.int 8 + %378 = torch.aten.div.Scalar %374, %int8_95 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %379 = torch.aten.where.self %377, %378, %374 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %380 = torch.aten.reciprocal %376 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192 = torch.constant.int 8192 + %381 = torch.aten.mul.Scalar %380, %int8192 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_96 = torch.constant.int 1 - %int0_97 = torch.constant.int 0 - %int9223372036854775807_98 = torch.constant.int 9223372036854775807 + %int1_97 = torch.constant.int 1 + %382 = torch.aten.sub.Scalar %381, %int1_96, %int1_97 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_98 = torch.constant.int 3 + %383 = torch.aten.div.Scalar %382, %int3_98 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_99 = torch.constant.int 1 - %379 = torch.aten.slice.Tensor %378, %int1_96, %int0_97, %int9223372036854775807_98, %int1_99 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %379, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_100 = torch.constant.int 2 - %int0_101 = torch.constant.int 0 - %int9223372036854775807_102 = torch.constant.int 9223372036854775807 - %int1_103 = torch.constant.int 1 - %380 = torch.aten.slice.Tensor %379, %int2_100, %int0_101, %int9223372036854775807_102, %int1_103 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %380, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_104 = torch.constant.int 4 - %int1_105 = 
torch.constant.int 1 + %int1_100 = torch.constant.int 1 + %384 = torch.aten.rsub.Scalar %383, %int1_99, %int1_100 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %385 = torch.aten.mul.Tensor %384, %379 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_101 = torch.constant.int 8 + %386 = torch.aten.div.Scalar %385, %int8_101 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %387 = torch.aten.mul.Tensor %383, %379 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_102 = torch.constant.int 1 + %388 = torch.aten.add.Tensor %386, %387, %int1_102 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03 = torch.constant.float 2.048000e+03 + %389 = torch.aten.lt.Scalar %376, %float2.048000e03 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %390 = torch.aten.bitwise_not %389 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_103 = torch.constant.float 8.192000e+03 + %391 = torch.aten.gt.Scalar %376, %float8.192000e03_103 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %392 = torch.aten.bitwise_not %391 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %393 = torch.aten.mul.Tensor %390, %392 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %394 = torch.aten.where.self %393, %388, %379 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %395 = torch.prim.ListConstruct %394, %394 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_104 = torch.constant.int -1 + %396 = torch.aten.cat %395, %int-1_104 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_105 = torch.constant.int 6 + %397 = torch.prims.convert_element_type %396, %int6_105 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> %int1_106 = torch.constant.int 1 - %381 = torch.prim.ListConstruct %int4_104, %int1_105, %int1_106 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %382 = torch.aten.repeat %380, %381 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %382, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> + %398 = torch.aten.unsqueeze %368, %int1_106 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> %int6_107 = torch.constant.int 6 - %383 = torch.prims.convert_element_type %330, %int6_107 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %383, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %384 = torch_c.to_builtin_tensor %383 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %385 = torch_c.to_builtin_tensor %382 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %386 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%384, %385) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %387 = torch_c.from_builtin_tensor %386 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %387, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_108 = torch.constant.int 5 - %388 = torch.prims.convert_element_type %387, %int5_108 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - 
torch.bind_symbolic_shape %388, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_109 = torch.constant.int 0 - %389 = torch.aten.size.int %291, %int0_109 : !torch.vtensor<[?,2097152],f16>, !torch.int -> !torch.int - %int32_110 = torch.constant.int 32 - %int2_111 = torch.constant.int 2 - %int32_112 = torch.constant.int 32 - %int8_113 = torch.constant.int 8 - %int128_114 = torch.constant.int 128 - %390 = torch.prim.ListConstruct %389, %int32_110, %int2_111, %int32_112, %int8_113, %int128_114 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %391 = torch.aten.view %291, %390 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %391, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_115 = torch.constant.int 32 - %392 = torch.aten.mul.int %389, %int32_115 : !torch.int, !torch.int -> !torch.int - %int2_116 = torch.constant.int 2 - %393 = torch.aten.mul.int %392, %int2_116 : !torch.int, !torch.int -> !torch.int - %int32_117 = torch.constant.int 32 - %int8_118 = torch.constant.int 8 - %int128_119 = torch.constant.int 128 - %394 = torch.prim.ListConstruct %393, %int32_117, %int8_118, %int128_119 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %395 = torch.aten.view %391, %394 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %395, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int64 = torch.constant.int 64 - %396 = torch.aten.mul.Scalar %arg2, %int64 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %396, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %399 = torch.prims.convert_element_type %398, %int6_107 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_108 = torch.constant.int 0 + %400 = torch.aten.unsqueeze %397, %int0_108 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_109 = torch.constant.int 6 + %401 = torch.prims.convert_element_type %400, %int6_109 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %402 = torch.aten.mul.Tensor %399, %401 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %403 = torch.aten.cos %402 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_110 = torch.constant.int 5 + %404 = torch.prims.convert_element_type %403, %int5_110 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %405 = torch.aten.sin %402 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_111 = torch.constant.int 5 + %406 = torch.prims.convert_element_type %405, %int5_111 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_112 = torch.constant.int 0 + %int0_113 = torch.constant.int 0 + %int1_114 = torch.constant.int 1 + %407 = torch.aten.slice.Tensor %404, %int0_112, %int0_113, %298, %int1_114 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %407, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_115 = torch.constant.int 1 + %int0_116 = torch.constant.int 0 + %int9223372036854775807_117 = torch.constant.int 
9223372036854775807 + %int1_118 = torch.constant.int 1 + %408 = torch.aten.slice.Tensor %407, %int1_115, %int0_116, %int9223372036854775807_117, %int1_118 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %408, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_119 = torch.constant.int 0 %int0_120 = torch.constant.int 0 %int1_121 = torch.constant.int 1 - %397 = torch.aten.add.Scalar %396, %int0_120, %int1_121 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %397, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %409 = torch.aten.slice.Tensor %406, %int0_119, %int0_120, %298, %int1_121 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %409, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_122 = torch.constant.int 1 - %398 = torch.aten.size.int %arg2, %int1_122 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.int - %int4_123 = torch.constant.int 4 - %int32_124 = torch.constant.int 32 - %int8_125 = torch.constant.int 8 - %int128_126 = torch.constant.int 128 - %399 = torch.prim.ListConstruct %int4_123, %398, %int32_124, %int8_125, %int128_126 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %400 = torch.aten.view %388, %399 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %400, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_127 = torch.constant.int 4 - %401 = torch.aten.mul.int %int4_127, %398 : !torch.int, !torch.int -> !torch.int - %int32_128 = torch.constant.int 32 - %int8_129 = torch.constant.int 8 - %int128_130 = torch.constant.int 128 - %402 = torch.prim.ListConstruct %401, %int32_128, %int8_129, %int128_130 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %403 = torch.aten.view %400, %402 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %403, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_131 = torch.constant.int 4 - %404 = torch.aten.mul.int %int4_131, %398 : !torch.int, !torch.int -> !torch.int - %405 = torch.prim.ListConstruct %404 : (!torch.int) -> !torch.list - %406 = torch.aten.view %397, %405 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %406, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %407 = torch.prim.ListConstruct %406 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_132 = torch.constant.bool false - %408 = torch.aten.index_put %395, %407, %403, %false_132 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %408, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_133 = torch.constant.int 32 - %int2_134 = torch.constant.int 2 - %int32_135 = torch.constant.int 32 - %int8_136 = torch.constant.int 8 - %int128_137 = torch.constant.int 128 - %409 = torch.prim.ListConstruct %389, %int32_133, %int2_134, %int32_135, %int8_136, %int128_137 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %410 = torch.aten.view %408, %409 : 
!torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %410, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152 = torch.constant.int 2097152 - %411 = torch.prim.ListConstruct %389, %int2097152 : (!torch.int, !torch.int) -> !torch.list - %412 = torch.aten.view %410, %411 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %412, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_138 = torch.constant.int 32 - %int2_139 = torch.constant.int 2 - %int32_140 = torch.constant.int 32 - %int8_141 = torch.constant.int 8 - %int128_142 = torch.constant.int 128 - %413 = torch.prim.ListConstruct %389, %int32_138, %int2_139, %int32_140, %int8_141, %int128_142 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %414 = torch.aten.view %412, %413 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %414, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_143 = torch.constant.int 32 - %int8_144 = torch.constant.int 8 - %int128_145 = torch.constant.int 128 - %415 = torch.prim.ListConstruct %393, %int32_143, %int8_144, %int128_145 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %416 = torch.aten.view %414, %415 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %416, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_146 = torch.constant.int 4 - %int32_147 = torch.constant.int 32 - %int8_148 = torch.constant.int 8 - %int128_149 = torch.constant.int 128 - %417 = torch.prim.ListConstruct %int4_146, %398, %int32_147, %int8_148, %int128_149 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %418 = torch.aten.view %332, %417 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %418, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_123 = torch.constant.int 0 + %int9223372036854775807_124 = torch.constant.int 9223372036854775807 + %int1_125 = torch.constant.int 1 + %410 = torch.aten.slice.Tensor %409, %int1_122, %int0_123, %int9223372036854775807_124, %int1_125 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %410, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_126 = torch.constant.int 0 + %411 = torch.aten.unsqueeze %408, %int0_126 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %411, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_127 = torch.constant.int 1 + %int0_128 = torch.constant.int 0 + %int9223372036854775807_129 = torch.constant.int 9223372036854775807 + %int1_130 = torch.constant.int 1 + %412 = torch.aten.slice.Tensor %411, %int1_127, %int0_128, %int9223372036854775807_129, %int1_130 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %412, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_131 = torch.constant.int 2 + %413 = 
torch.aten.unsqueeze %412, %int2_131 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %413, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_132 = torch.constant.int 3 + %int0_133 = torch.constant.int 0 + %int9223372036854775807_134 = torch.constant.int 9223372036854775807 + %int1_135 = torch.constant.int 1 + %414 = torch.aten.slice.Tensor %413, %int3_132, %int0_133, %int9223372036854775807_134, %int1_135 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %414, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_136 = torch.constant.int 4 + %int1_137 = torch.constant.int 1 + %int1_138 = torch.constant.int 1 + %int1_139 = torch.constant.int 1 + %415 = torch.prim.ListConstruct %int4_136, %int1_137, %int1_138, %int1_139 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %416 = torch.aten.repeat %414, %415 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %416, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_140 = torch.constant.int 0 + %417 = torch.aten.unsqueeze %410, %int0_140 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %417, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_141 = torch.constant.int 1 + %int0_142 = torch.constant.int 0 + %int9223372036854775807_143 = torch.constant.int 9223372036854775807 + %int1_144 = torch.constant.int 1 + %418 = torch.aten.slice.Tensor %417, %int1_141, %int0_142, %int9223372036854775807_143, %int1_144 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %418, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_145 = torch.constant.int 2 + %419 = torch.aten.unsqueeze %418, %int2_145 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %419, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_146 = torch.constant.int 3 + %int0_147 = torch.constant.int 0 + %int9223372036854775807_148 = torch.constant.int 9223372036854775807 + %int1_149 = torch.constant.int 1 + %420 = torch.aten.slice.Tensor %419, %int3_146, %int0_147, %int9223372036854775807_148, %int1_149 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %420, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> %int4_150 = torch.constant.int 4 - %419 = torch.aten.mul.int %int4_150, %398 : !torch.int, !torch.int -> !torch.int - %int32_151 = torch.constant.int 32 - %int8_152 = torch.constant.int 8 - %int128_153 = torch.constant.int 128 - %420 = torch.prim.ListConstruct %419, %int32_151, %int8_152, %int128_153 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %421 = torch.aten.view %418, %420 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %421, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_154 = torch.constant.int 1 - %int1_155 = torch.constant.int 1 - %422 = torch.aten.add.Scalar %397, 
%int1_154, %int1_155 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %422, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_156 = torch.constant.int 4 - %423 = torch.aten.mul.int %int4_156, %398 : !torch.int, !torch.int -> !torch.int - %424 = torch.prim.ListConstruct %423 : (!torch.int) -> !torch.list - %425 = torch.aten.view %422, %424 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %425, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %426 = torch.prim.ListConstruct %425 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_157 = torch.constant.bool false - %427 = torch.aten.index_put %416, %426, %421, %false_157 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %427, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_158 = torch.constant.int 32 - %int2_159 = torch.constant.int 2 - %int32_160 = torch.constant.int 32 - %int8_161 = torch.constant.int 8 - %int128_162 = torch.constant.int 128 - %428 = torch.prim.ListConstruct %389, %int32_158, %int2_159, %int32_160, %int8_161, %int128_162 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %429 = torch.aten.view %427, %428 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %429, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_163 = torch.constant.int 2097152 - %430 = torch.prim.ListConstruct %389, %int2097152_163 : (!torch.int, !torch.int) -> !torch.list - %431 = torch.aten.view %429, %430 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %431, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_164 = torch.constant.int -2 - %432 = torch.aten.unsqueeze %388, %int-2_164 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %432, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_165 = torch.constant.int 4 - %int8_166 = torch.constant.int 8 - %int4_167 = torch.constant.int 4 - %int128_168 = torch.constant.int 128 - %433 = torch.prim.ListConstruct %int4_165, %373, %int8_166, %int4_167, %int128_168 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_169 = torch.constant.bool false - %434 = torch.aten.expand %432, %433, %false_169 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %434, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_170 = torch.constant.int 0 - %435 = torch.aten.clone %434, %int0_170 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %435, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int1_151 = torch.constant.int 1 + %int1_152 = torch.constant.int 1 + %int1_153 = torch.constant.int 1 + %421 = torch.prim.ListConstruct %int4_150, %int1_151, %int1_152, %int1_153 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %422 = torch.aten.repeat %420, %421 : 
!torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %422, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %423 = torch.aten.mul.Tensor %363, %416 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %423, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_154 = torch.constant.int 3 + %int0_155 = torch.constant.int 0 + %int64 = torch.constant.int 64 + %int1_156 = torch.constant.int 1 + %424 = torch.aten.slice.Tensor %363, %int3_154, %int0_155, %int64, %int1_156 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %424, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_157 = torch.constant.int 3 + %int64_158 = torch.constant.int 64 + %int9223372036854775807_159 = torch.constant.int 9223372036854775807 + %int1_160 = torch.constant.int 1 + %425 = torch.aten.slice.Tensor %363, %int3_157, %int64_158, %int9223372036854775807_159, %int1_160 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %425, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %426 = torch.aten.neg %425 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %426, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %427 = torch.prim.ListConstruct %426, %424 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_161 = torch.constant.int -1 + %428 = torch.aten.cat %427, %int-1_161 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %428, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %429 = torch.aten.mul.Tensor %428, %422 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %429, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_162 = torch.constant.int 1 + %430 = torch.aten.add.Tensor %423, %429, %int1_162 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %430, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_163 = torch.constant.int 131072 + %none_164 = torch.constant.none + %none_165 = torch.constant.none + %cpu_166 = torch.constant.device "cpu" + %false_167 = torch.constant.bool false + %431 = torch.aten.arange %int131072_163, %none_164, %none_165, %cpu_166, %false_167 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_168 = torch.constant.int 0 + %int128_169 = torch.constant.int 128 + %int2_170 = torch.constant.int 2 %int4_171 = torch.constant.int 4 - %int32_172 = torch.constant.int 32 - %int128_173 = torch.constant.int 128 - %436 = torch.prim.ListConstruct %int4_171, %373, %int32_172, %int128_173 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %437 = torch.aten._unsafe_view %435, %436 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %437, [%292], 
affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_174 = torch.constant.int -2 - %438 = torch.aten.unsqueeze %332, %int-2_174 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %438, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_175 = torch.constant.int 1 - %439 = torch.aten.size.int %326, %int1_175 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_176 = torch.constant.int 4 - %int8_177 = torch.constant.int 8 - %int4_178 = torch.constant.int 4 - %int128_179 = torch.constant.int 128 - %440 = torch.prim.ListConstruct %int4_176, %439, %int8_177, %int4_178, %int128_179 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_180 = torch.constant.bool false - %441 = torch.aten.expand %438, %440, %false_180 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %441, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_181 = torch.constant.int 0 - %442 = torch.aten.clone %441, %int0_181 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %442, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_182 = torch.constant.int 4 - %int32_183 = torch.constant.int 32 - %int128_184 = torch.constant.int 128 - %443 = torch.prim.ListConstruct %int4_182, %439, %int32_183, %int128_184 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %444 = torch.aten._unsafe_view %442, %443 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %444, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_185 = torch.constant.int 1 - %int2_186 = torch.constant.int 2 - %445 = torch.aten.transpose.int %360, %int1_185, %int2_186 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %445, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %none_172 = torch.constant.none + %cpu_173 = torch.constant.device "cpu" + %false_174 = torch.constant.bool false + %432 = torch.aten.arange.start_step %int0_168, %int128_169, %int2_170, %int4_171, %none_172, %cpu_173, %false_174 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_175 = torch.constant.int 6 + %433 = torch.prims.convert_element_type %432, %int6_175 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_176 = torch.constant.int 128 + %434 = torch.aten.div.Scalar %433, %int128_176 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_177 = torch.constant.float 5.000000e+05 + %435 = torch.aten.pow.Scalar %float5.000000e05_177, %434 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %436 = torch.aten.reciprocal %435 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_178 = torch.constant.float 1.000000e+00 + %437 = torch.aten.mul.Scalar %436, %float1.000000e00_178 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %438 = torch.aten.reciprocal %437 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_179 = torch.constant.float 
6.2831853071795862 + %439 = torch.aten.mul.Scalar %438, %float6.283190e00_179 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_180 = torch.constant.float 8.192000e+03 + %440 = torch.aten.gt.Scalar %439, %float8.192000e03_180 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_181 = torch.constant.int 8 + %441 = torch.aten.div.Scalar %437, %int8_181 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %442 = torch.aten.where.self %440, %441, %437 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %443 = torch.aten.reciprocal %439 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_182 = torch.constant.int 8192 + %444 = torch.aten.mul.Scalar %443, %int8192_182 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_183 = torch.constant.int 1 + %int1_184 = torch.constant.int 1 + %445 = torch.aten.sub.Scalar %444, %int1_183, %int1_184 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_185 = torch.constant.int 3 + %446 = torch.aten.div.Scalar %445, %int3_185 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_186 = torch.constant.int 1 %int1_187 = torch.constant.int 1 - %int2_188 = torch.constant.int 2 - %446 = torch.aten.transpose.int %437, %int1_187, %int2_188 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %446, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %447 = torch.aten.rsub.Scalar %446, %int1_186, %int1_187 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %448 = torch.aten.mul.Tensor %447, %442 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_188 = torch.constant.int 8 + %449 = torch.aten.div.Scalar %448, %int8_188 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %450 = torch.aten.mul.Tensor %446, %442 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> %int1_189 = torch.constant.int 1 - %int2_190 = torch.constant.int 2 - %447 = torch.aten.transpose.int %444, %int1_189, %int2_190 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %447, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00 = torch.constant.float 0.000000e+00 - %true_191 = torch.constant.bool true - %none_192 = torch.constant.none - %none_193 = torch.constant.none - %448:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%445, %446, %447, %float0.000000e00, %true_191, %none_192, %none_193) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %448#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %451 = torch.aten.add.Tensor %449, %450, %int1_189 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_190 = torch.constant.float 2.048000e+03 + %452 = torch.aten.lt.Scalar %439, %float2.048000e03_190 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %453 = torch.aten.bitwise_not %452 : 
!torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_191 = torch.constant.float 8.192000e+03 + %454 = torch.aten.gt.Scalar %439, %float8.192000e03_191 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %455 = torch.aten.bitwise_not %454 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %456 = torch.aten.mul.Tensor %453, %455 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %457 = torch.aten.where.self %456, %451, %442 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %458 = torch.prim.ListConstruct %457, %457 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_192 = torch.constant.int -1 + %459 = torch.aten.cat %458, %int-1_192 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_193 = torch.constant.int 6 + %460 = torch.prims.convert_element_type %459, %int6_193 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> %int1_194 = torch.constant.int 1 - %int2_195 = torch.constant.int 2 - %449 = torch.aten.transpose.int %448#0, %int1_194, %int2_195 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %449, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_196 = torch.constant.int 4 - %int4096_197 = torch.constant.int 4096 - %450 = torch.prim.ListConstruct %int4_196, %345, %int4096_197 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %451 = torch.aten.view %449, %450 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %451, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_198 = torch.constant.int -2 - %int-1_199 = torch.constant.int -1 - %452 = torch.aten.transpose.int %5, %int-2_198, %int-1_199 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_200 = torch.constant.int 4 - %453 = torch.aten.mul.int %int4_200, %345 : !torch.int, !torch.int -> !torch.int - %int4096_201 = torch.constant.int 4096 - %454 = torch.prim.ListConstruct %453, %int4096_201 : (!torch.int, !torch.int) -> !torch.list - %455 = torch.aten.view %451, %454 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %455, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %456 = torch.aten.mm %455, %452 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %456, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_202 = torch.constant.int 4 - %int4096_203 = torch.constant.int 4096 - %457 = torch.prim.ListConstruct %int4_202, %345, %int4096_203 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %458 = torch.aten.view %456, %457 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %458, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_204 = torch.constant.int 1 - %459 = torch.aten.add.Tensor %294, %458, %int1_204 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %459, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_205 = torch.constant.int 6 - %460 = torch.prims.convert_element_type 
%459, %int6_205 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %460, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_206 = torch.constant.int 2 - %461 = torch.aten.pow.Tensor_Scalar %460, %int2_206 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %461, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_207 = torch.constant.int -1 - %462 = torch.prim.ListConstruct %int-1_207 : (!torch.int) -> !torch.list - %true_208 = torch.constant.bool true - %none_209 = torch.constant.none - %463 = torch.aten.mean.dim %461, %462, %true_208, %none_209 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %463, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_210 = torch.constant.float 9.9999997473787516E-6 - %int1_211 = torch.constant.int 1 - %464 = torch.aten.add.Scalar %463, %float9.999990e-06_210, %int1_211 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %464, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %465 = torch.aten.rsqrt %464 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %465, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %466 = torch.aten.mul.Tensor %460, %465 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %466, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_212 = torch.constant.int 5 - %467 = torch.prims.convert_element_type %466, %int5_212 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %467, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %468 = torch.aten.mul.Tensor %6, %467 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %468, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_213 = torch.constant.int 5 - %469 = torch.prims.convert_element_type %468, %int5_213 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %469, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_214 = torch.constant.int -2 - %int-1_215 = torch.constant.int -1 - %470 = torch.aten.transpose.int %7, %int-2_214, %int-1_215 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_216 = torch.constant.int 4 - %471 = torch.aten.mul.int %int4_216, %306 : !torch.int, !torch.int -> !torch.int - %int4096_217 = torch.constant.int 4096 - %472 = torch.prim.ListConstruct %471, %int4096_217 : (!torch.int, !torch.int) -> !torch.list - %473 = torch.aten.view %469, %472 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %473, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %474 = torch.aten.mm %473, %470 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %474, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : 
!torch.vtensor<[?,14336],f16> - %int4_218 = torch.constant.int 4 - %int14336 = torch.constant.int 14336 - %475 = torch.prim.ListConstruct %int4_218, %306, %int14336 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %476 = torch.aten.view %474, %475 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %476, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %477 = torch.aten.silu %476 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %477, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_219 = torch.constant.int -2 - %int-1_220 = torch.constant.int -1 - %478 = torch.aten.transpose.int %8, %int-2_219, %int-1_220 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_221 = torch.constant.int 4 - %479 = torch.aten.mul.int %int4_221, %306 : !torch.int, !torch.int -> !torch.int - %int4096_222 = torch.constant.int 4096 - %480 = torch.prim.ListConstruct %479, %int4096_222 : (!torch.int, !torch.int) -> !torch.list - %481 = torch.aten.view %469, %480 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %481, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %482 = torch.aten.mm %481, %478 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %482, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_223 = torch.constant.int 4 - %int14336_224 = torch.constant.int 14336 - %483 = torch.prim.ListConstruct %int4_223, %306, %int14336_224 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %484 = torch.aten.view %482, %483 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %484, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %485 = torch.aten.mul.Tensor %477, %484 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %485, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_225 = torch.constant.int -2 - %int-1_226 = torch.constant.int -1 - %486 = torch.aten.transpose.int %9, %int-2_225, %int-1_226 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %461 = torch.aten.unsqueeze %431, %int1_194 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_195 = torch.constant.int 6 + %462 = torch.prims.convert_element_type %461, %int6_195 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_196 = torch.constant.int 0 + %463 = torch.aten.unsqueeze %460, %int0_196 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_197 = torch.constant.int 6 + %464 = torch.prims.convert_element_type %463, %int6_197 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %465 = torch.aten.mul.Tensor %462, %464 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %466 = torch.aten.cos %465 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_198 = torch.constant.int 5 + %467 = torch.prims.convert_element_type %466, %int5_198 : !torch.vtensor<[131072,128],f32>, !torch.int -> 
!torch.vtensor<[131072,128],f16> + %468 = torch.aten.sin %465 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_199 = torch.constant.int 5 + %469 = torch.prims.convert_element_type %468, %int5_199 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_200 = torch.constant.int 0 + %int0_201 = torch.constant.int 0 + %int1_202 = torch.constant.int 1 + %470 = torch.aten.slice.Tensor %467, %int0_200, %int0_201, %298, %int1_202 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %470, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_203 = torch.constant.int 1 + %int0_204 = torch.constant.int 0 + %int9223372036854775807_205 = torch.constant.int 9223372036854775807 + %int1_206 = torch.constant.int 1 + %471 = torch.aten.slice.Tensor %470, %int1_203, %int0_204, %int9223372036854775807_205, %int1_206 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %471, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_207 = torch.constant.int 0 + %int0_208 = torch.constant.int 0 + %int1_209 = torch.constant.int 1 + %472 = torch.aten.slice.Tensor %469, %int0_207, %int0_208, %298, %int1_209 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %472, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_210 = torch.constant.int 1 + %int0_211 = torch.constant.int 0 + %int9223372036854775807_212 = torch.constant.int 9223372036854775807 + %int1_213 = torch.constant.int 1 + %473 = torch.aten.slice.Tensor %472, %int1_210, %int0_211, %int9223372036854775807_212, %int1_213 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %473, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_214 = torch.constant.int 0 + %474 = torch.aten.unsqueeze %471, %int0_214 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %474, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_215 = torch.constant.int 1 + %int0_216 = torch.constant.int 0 + %int9223372036854775807_217 = torch.constant.int 9223372036854775807 + %int1_218 = torch.constant.int 1 + %475 = torch.aten.slice.Tensor %474, %int1_215, %int0_216, %int9223372036854775807_217, %int1_218 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %475, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_219 = torch.constant.int 2 + %476 = torch.aten.unsqueeze %475, %int2_219 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %476, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_220 = torch.constant.int 3 + %int0_221 = torch.constant.int 0 + %int9223372036854775807_222 = torch.constant.int 9223372036854775807 + %int1_223 = torch.constant.int 1 + %477 = torch.aten.slice.Tensor %476, %int3_220, %int0_221, %int9223372036854775807_222, %int1_223 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + 
torch.bind_symbolic_shape %477, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_224 = torch.constant.int 4 + %int1_225 = torch.constant.int 1 + %int1_226 = torch.constant.int 1 %int1_227 = torch.constant.int 1 - %487 = torch.aten.size.int %476, %int1_227 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_228 = torch.constant.int 4 - %488 = torch.aten.mul.int %int4_228, %487 : !torch.int, !torch.int -> !torch.int - %int14336_229 = torch.constant.int 14336 - %489 = torch.prim.ListConstruct %488, %int14336_229 : (!torch.int, !torch.int) -> !torch.list - %490 = torch.aten.view %485, %489 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %490, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %491 = torch.aten.mm %490, %486 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %491, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_230 = torch.constant.int 4 - %int4096_231 = torch.constant.int 4096 - %492 = torch.prim.ListConstruct %int4_230, %487, %int4096_231 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %493 = torch.aten.view %491, %492 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %493, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %478 = torch.prim.ListConstruct %int4_224, %int1_225, %int1_226, %int1_227 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %479 = torch.aten.repeat %477, %478 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %479, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_228 = torch.constant.int 0 + %480 = torch.aten.unsqueeze %473, %int0_228 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %480, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_229 = torch.constant.int 1 + %int0_230 = torch.constant.int 0 + %int9223372036854775807_231 = torch.constant.int 9223372036854775807 %int1_232 = torch.constant.int 1 - %494 = torch.aten.add.Tensor %459, %493, %int1_232 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %494, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_233 = torch.constant.int 6 - %495 = torch.prims.convert_element_type %494, %int6_233 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %495, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_234 = torch.constant.int 2 - %496 = torch.aten.pow.Tensor_Scalar %495, %int2_234 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %496, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_235 = torch.constant.int -1 - %497 = torch.prim.ListConstruct %int-1_235 : (!torch.int) -> !torch.list - %true_236 = torch.constant.bool true - %none_237 = torch.constant.none - %498 = torch.aten.mean.dim %496, %497, %true_236, %none_237 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - 
torch.bind_symbolic_shape %498, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_238 = torch.constant.float 9.9999997473787516E-6 + %481 = torch.aten.slice.Tensor %480, %int1_229, %int0_230, %int9223372036854775807_231, %int1_232 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %481, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_233 = torch.constant.int 2 + %482 = torch.aten.unsqueeze %481, %int2_233 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %482, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_234 = torch.constant.int 3 + %int0_235 = torch.constant.int 0 + %int9223372036854775807_236 = torch.constant.int 9223372036854775807 + %int1_237 = torch.constant.int 1 + %483 = torch.aten.slice.Tensor %482, %int3_234, %int0_235, %int9223372036854775807_236, %int1_237 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %483, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_238 = torch.constant.int 4 %int1_239 = torch.constant.int 1 - %499 = torch.aten.add.Scalar %498, %float9.999990e-06_238, %int1_239 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %499, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %500 = torch.aten.rsqrt %499 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %500, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %501 = torch.aten.mul.Tensor %495, %500 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %501, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_240 = torch.constant.int 5 - %502 = torch.prims.convert_element_type %501, %int5_240 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %502, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %503 = torch.aten.mul.Tensor %10, %502 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %503, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_241 = torch.constant.int 5 - %504 = torch.prims.convert_element_type %503, %int5_241 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %504, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_242 = torch.constant.int -2 - %int-1_243 = torch.constant.int -1 - %505 = torch.aten.transpose.int %11, %int-2_242, %int-1_243 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_244 = torch.constant.int 4 - %506 = torch.aten.mul.int %int4_244, %306 : !torch.int, !torch.int -> !torch.int - %int4096_245 = torch.constant.int 4096 - %507 = torch.prim.ListConstruct %506, %int4096_245 : (!torch.int, !torch.int) -> !torch.list - %508 = torch.aten.view %504, %507 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %508, [%292], 
affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %509 = torch.aten.mm %508, %505 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %509, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_246 = torch.constant.int 4 - %int4096_247 = torch.constant.int 4096 - %510 = torch.prim.ListConstruct %int4_246, %306, %int4096_247 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %511 = torch.aten.view %509, %510 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %511, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_248 = torch.constant.int -2 - %int-1_249 = torch.constant.int -1 - %512 = torch.aten.transpose.int %12, %int-2_248, %int-1_249 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_250 = torch.constant.int 4 - %513 = torch.aten.mul.int %int4_250, %306 : !torch.int, !torch.int -> !torch.int - %int4096_251 = torch.constant.int 4096 - %514 = torch.prim.ListConstruct %513, %int4096_251 : (!torch.int, !torch.int) -> !torch.list - %515 = torch.aten.view %504, %514 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %515, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %516 = torch.aten.mm %515, %512 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %516, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_252 = torch.constant.int 4 - %int1024_253 = torch.constant.int 1024 - %517 = torch.prim.ListConstruct %int4_252, %306, %int1024_253 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %518 = torch.aten.view %516, %517 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %518, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_254 = torch.constant.int -2 - %int-1_255 = torch.constant.int -1 - %519 = torch.aten.transpose.int %13, %int-2_254, %int-1_255 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_256 = torch.constant.int 4 - %520 = torch.aten.mul.int %int4_256, %306 : !torch.int, !torch.int -> !torch.int - %int4096_257 = torch.constant.int 4096 - %521 = torch.prim.ListConstruct %520, %int4096_257 : (!torch.int, !torch.int) -> !torch.list - %522 = torch.aten.view %504, %521 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %522, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %523 = torch.aten.mm %522, %519 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %523, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_258 = torch.constant.int 4 - %int1024_259 = torch.constant.int 1024 - %524 = torch.prim.ListConstruct %int4_258, %306, %int1024_259 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %525 = torch.aten.view %523, %524 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %525, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_260 = torch.constant.int 4 - %int32_261 = torch.constant.int 
32 - %int128_262 = torch.constant.int 128 - %526 = torch.prim.ListConstruct %int4_260, %306, %int32_261, %int128_262 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %527 = torch.aten.view %511, %526 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %527, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_263 = torch.constant.int 4 - %int8_264 = torch.constant.int 8 - %int128_265 = torch.constant.int 128 - %528 = torch.prim.ListConstruct %int4_263, %306, %int8_264, %int128_265 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %529 = torch.aten.view %518, %528 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %529, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_266 = torch.constant.int 4 - %int8_267 = torch.constant.int 8 - %int128_268 = torch.constant.int 128 - %530 = torch.prim.ListConstruct %int4_266, %306, %int8_267, %int128_268 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %531 = torch.aten.view %525, %530 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %531, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_269 = torch.constant.int 131072 - %none_270 = torch.constant.none - %none_271 = torch.constant.none - %cpu_272 = torch.constant.device "cpu" - %false_273 = torch.constant.bool false - %532 = torch.aten.arange %int131072_269, %none_270, %none_271, %cpu_272, %false_273 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_274 = torch.constant.int 0 + %int1_240 = torch.constant.int 1 + %int1_241 = torch.constant.int 1 + %484 = torch.prim.ListConstruct %int4_238, %int1_239, %int1_240, %int1_241 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %485 = torch.aten.repeat %483, %484 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %485, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %486 = torch.aten.mul.Tensor %365, %479 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %486, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_242 = torch.constant.int 3 + %int0_243 = torch.constant.int 0 + %int64_244 = torch.constant.int 64 + %int1_245 = torch.constant.int 1 + %487 = torch.aten.slice.Tensor %365, %int3_242, %int0_243, %int64_244, %int1_245 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %487, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_246 = torch.constant.int 3 + %int64_247 = torch.constant.int 64 + %int9223372036854775807_248 = torch.constant.int 9223372036854775807 + %int1_249 = torch.constant.int 1 + %488 = torch.aten.slice.Tensor %365, %int3_246, %int64_247, %int9223372036854775807_248, %int1_249 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %488, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %489 = torch.aten.neg %488 : !torch.vtensor<[4,?,8,64],f16> 
-> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %489, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %490 = torch.prim.ListConstruct %489, %487 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_250 = torch.constant.int -1 + %491 = torch.aten.cat %490, %int-1_250 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %491, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %492 = torch.aten.mul.Tensor %491, %485 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %492, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_251 = torch.constant.int 1 + %493 = torch.aten.add.Tensor %486, %492, %int1_251 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %493, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_252 = torch.constant.int 32 + %int2_253 = torch.constant.int 2 + %int8_254 = torch.constant.int 8 + %int32_255 = torch.constant.int 32 + %int128_256 = torch.constant.int 128 + %494 = torch.prim.ListConstruct %297, %int32_252, %int2_253, %int8_254, %int32_255, %int128_256 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %495 = torch.aten.view %292, %494 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %495, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int32_257 = torch.constant.int 32 + %496 = torch.aten.mul.int %297, %int32_257 : !torch.int, !torch.int -> !torch.int + %int2_258 = torch.constant.int 2 + %497 = torch.aten.mul.int %496, %int2_258 : !torch.int, !torch.int -> !torch.int + %int8_259 = torch.constant.int 8 + %int32_260 = torch.constant.int 32 + %int128_261 = torch.constant.int 128 + %498 = torch.prim.ListConstruct %497, %int8_259, %int32_260, %int128_261 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %499 = torch.aten.view %495, %498 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %499, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_262 = torch.constant.int 32 + %500 = torch.aten.mul.Scalar %arg2, %int32_262 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %500, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_263 = torch.constant.int 0 + %int1_264 = torch.constant.int 1 + %501 = torch.aten.add.Scalar %500, %int0_263, %int1_264 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %501, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_265 = torch.constant.int 2 + %502 = torch.aten.mul.Scalar %501, %int2_265 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %502, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_266 = torch.constant.int 0 + %int1_267 = torch.constant.int 1 + %503 = torch.aten.add.Scalar %502, %int0_266, %int1_267 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %503, 
[%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int4_268 = torch.constant.int 4 + %504 = torch.aten.mul.int %int4_268, %296 : !torch.int, !torch.int -> !torch.int + %505 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %506 = torch.aten.view %503, %505 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %506, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_269 = torch.constant.int 4 + %int32_270 = torch.constant.int 32 + %int8_271 = torch.constant.int 8 + %int128_272 = torch.constant.int 128 + %507 = torch.prim.ListConstruct %int4_269, %296, %int32_270, %int8_271, %int128_272 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %508 = torch.aten.view %493, %507 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %508, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_273 = torch.constant.int 32 + %int8_274 = torch.constant.int 8 %int128_275 = torch.constant.int 128 - %none_276 = torch.constant.none - %none_277 = torch.constant.none - %cpu_278 = torch.constant.device "cpu" + %509 = torch.prim.ListConstruct %504, %int32_273, %int8_274, %int128_275 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %510 = torch.aten.view %508, %509 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %510, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_276 = torch.constant.int 1 + %int2_277 = torch.constant.int 2 + %511 = torch.aten.transpose.int %510, %int1_276, %int2_277 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %511, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_278 = torch.constant.int 5 + %512 = torch.prims.convert_element_type %511, %int5_278 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %512, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %513 = torch.prim.ListConstruct %506 : (!torch.vtensor<[?],si64>) -> !torch.list> %false_279 = torch.constant.bool false - %533 = torch.aten.arange.start %int0_274, %int128_275, %none_276, %none_277, %cpu_278, %false_279 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_280 = torch.constant.int 2 - %534 = torch.aten.floor_divide.Scalar %533, %int2_280 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_281 = torch.constant.int 6 - %535 = torch.prims.convert_element_type %534, %int6_281 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_282 = torch.constant.int 128 - %536 = torch.aten.div.Scalar %535, %int128_282 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_283 = torch.constant.float 2.000000e+00 - %537 = torch.aten.mul.Scalar %536, %float2.000000e00_283 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_284 = torch.constant.float 5.000000e+05 - %538 = torch.aten.pow.Scalar %float5.000000e05_284, %537 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %539 = torch.aten.reciprocal %538 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - 
%float1.000000e00_285 = torch.constant.float 1.000000e+00 - %540 = torch.aten.mul.Scalar %539, %float1.000000e00_285 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_286 = torch.constant.int 1 - %541 = torch.aten.unsqueeze %532, %int1_286 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_287 = torch.constant.int 0 - %542 = torch.aten.unsqueeze %540, %int0_287 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %543 = torch.aten.mul.Tensor %541, %542 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_288 = torch.constant.int 1 - %544 = torch.aten.size.int %511, %int1_288 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_289 = torch.constant.int 0 - %545 = torch.aten.add.int %int0_289, %544 : !torch.int, !torch.int -> !torch.int - %int0_290 = torch.constant.int 0 - %int0_291 = torch.constant.int 0 - %int1_292 = torch.constant.int 1 - %546 = torch.aten.slice.Tensor %543, %int0_290, %int0_291, %545, %int1_292 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %546, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_293 = torch.constant.int 1 + %514 = torch.aten.index_put %499, %513, %512, %false_279 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %514, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_280 = torch.constant.int 32 + %int2_281 = torch.constant.int 2 + %int8_282 = torch.constant.int 8 + %int32_283 = torch.constant.int 32 + %int128_284 = torch.constant.int 128 + %515 = torch.prim.ListConstruct %297, %int32_280, %int2_281, %int8_282, %int32_283, %int128_284 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %516 = torch.aten.view %514, %515 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %516, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152 = torch.constant.int 2097152 + %517 = torch.prim.ListConstruct %297, %int2097152 : (!torch.int, !torch.int) -> !torch.list + %518 = torch.aten.view %516, %517 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %518, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_285 = torch.constant.int 32 + %int2_286 = torch.constant.int 2 + %int8_287 = torch.constant.int 8 + %int32_288 = torch.constant.int 32 + %int128_289 = torch.constant.int 128 + %519 = torch.prim.ListConstruct %297, %int32_285, %int2_286, %int8_287, %int32_288, %int128_289 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %520 = torch.aten.view %518, %519 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %520, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_290 = torch.constant.int 8 + %int32_291 = torch.constant.int 32 + %int128_292 = torch.constant.int 128 + %521 = torch.prim.ListConstruct %497, %int8_290, %int32_291, %int128_292 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %522 = torch.aten.view %520, 
%521 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %522, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_293 = torch.constant.int 32 + %523 = torch.aten.mul.Scalar %arg2, %int32_293 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %523, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> %int0_294 = torch.constant.int 0 - %int9223372036854775807_295 = torch.constant.int 9223372036854775807 - %int1_296 = torch.constant.int 1 - %547 = torch.aten.slice.Tensor %546, %int1_293, %int0_294, %int9223372036854775807_295, %int1_296 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %547, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %int1_295 = torch.constant.int 1 + %524 = torch.aten.add.Scalar %523, %int0_294, %int1_295 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %524, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_296 = torch.constant.int 2 + %525 = torch.aten.mul.Scalar %524, %int2_296 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %525, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> %int1_297 = torch.constant.int 1 - %int0_298 = torch.constant.int 0 - %int9223372036854775807_299 = torch.constant.int 9223372036854775807 - %int1_300 = torch.constant.int 1 - %548 = torch.aten.slice.Tensor %547, %int1_297, %int0_298, %int9223372036854775807_299, %int1_300 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %548, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_301 = torch.constant.int 0 - %549 = torch.aten.unsqueeze %548, %int0_301 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %549, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_302 = torch.constant.int 1 - %int0_303 = torch.constant.int 0 - %int9223372036854775807_304 = torch.constant.int 9223372036854775807 - %int1_305 = torch.constant.int 1 - %550 = torch.aten.slice.Tensor %549, %int1_302, %int0_303, %int9223372036854775807_304, %int1_305 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %550, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_306 = torch.constant.int 2 - %int0_307 = torch.constant.int 0 - %int9223372036854775807_308 = torch.constant.int 9223372036854775807 - %int1_309 = torch.constant.int 1 - %551 = torch.aten.slice.Tensor %550, %int2_306, %int0_307, %int9223372036854775807_308, %int1_309 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %551, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_310 = torch.constant.int 4 - %int1_311 = torch.constant.int 1 - %int1_312 = torch.constant.int 1 - %552 = torch.prim.ListConstruct %int4_310, %int1_311, %int1_312 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %553 = torch.aten.repeat %551, %552 : !torch.vtensor<[1,?,128],f32>, !torch.list -> 
!torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %553, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_313 = torch.constant.int 6 - %554 = torch.prims.convert_element_type %527, %int6_313 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %554, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %555 = torch_c.to_builtin_tensor %554 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %556 = torch_c.to_builtin_tensor %553 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %557 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%555, %556) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %558 = torch_c.from_builtin_tensor %557 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %558, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_314 = torch.constant.int 5 - %559 = torch.prims.convert_element_type %558, %int5_314 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %559, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_315 = torch.constant.int 131072 - %none_316 = torch.constant.none - %none_317 = torch.constant.none - %cpu_318 = torch.constant.device "cpu" - %false_319 = torch.constant.bool false - %560 = torch.aten.arange %int131072_315, %none_316, %none_317, %cpu_318, %false_319 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_320 = torch.constant.int 0 - %int128_321 = torch.constant.int 128 - %none_322 = torch.constant.none - %none_323 = torch.constant.none - %cpu_324 = torch.constant.device "cpu" - %false_325 = torch.constant.bool false - %561 = torch.aten.arange.start %int0_320, %int128_321, %none_322, %none_323, %cpu_324, %false_325 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_326 = torch.constant.int 2 - %562 = torch.aten.floor_divide.Scalar %561, %int2_326 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_327 = torch.constant.int 6 - %563 = torch.prims.convert_element_type %562, %int6_327 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_328 = torch.constant.int 128 - %564 = torch.aten.div.Scalar %563, %int128_328 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_329 = torch.constant.float 2.000000e+00 - %565 = torch.aten.mul.Scalar %564, %float2.000000e00_329 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_330 = torch.constant.float 5.000000e+05 - %566 = torch.aten.pow.Scalar %float5.000000e05_330, %565 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %567 = torch.aten.reciprocal %566 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_331 = torch.constant.float 1.000000e+00 - %568 = torch.aten.mul.Scalar %567, %float1.000000e00_331 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_332 = torch.constant.int 1 - %569 = torch.aten.unsqueeze %560, %int1_332 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_333 = torch.constant.int 0 - %570 = torch.aten.unsqueeze %568, %int0_333 : !torch.vtensor<[128],f32>, !torch.int -> 
!torch.vtensor<[1,128],f32> - %571 = torch.aten.mul.Tensor %569, %570 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_334 = torch.constant.int 1 - %572 = torch.aten.size.int %518, %int1_334 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_335 = torch.constant.int 0 - %573 = torch.aten.add.int %int0_335, %572 : !torch.int, !torch.int -> !torch.int - %int0_336 = torch.constant.int 0 - %int0_337 = torch.constant.int 0 + %int1_298 = torch.constant.int 1 + %526 = torch.aten.add.Scalar %525, %int1_297, %int1_298 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %526, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %527 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %528 = torch.aten.view %526, %527 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %528, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_299 = torch.constant.int 4 + %int32_300 = torch.constant.int 32 + %int8_301 = torch.constant.int 8 + %int128_302 = torch.constant.int 128 + %529 = torch.prim.ListConstruct %int4_299, %296, %int32_300, %int8_301, %int128_302 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %530 = torch.aten.view %367, %529 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %530, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_303 = torch.constant.int 32 + %int8_304 = torch.constant.int 8 + %int128_305 = torch.constant.int 128 + %531 = torch.prim.ListConstruct %504, %int32_303, %int8_304, %int128_305 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %532 = torch.aten.view %530, %531 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %532, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_306 = torch.constant.int 1 + %int2_307 = torch.constant.int 2 + %533 = torch.aten.transpose.int %532, %int1_306, %int2_307 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %533, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_308 = torch.constant.int 5 + %534 = torch.prims.convert_element_type %533, %int5_308 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %534, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %535 = torch.prim.ListConstruct %528 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_309 = torch.constant.bool false + %536 = torch.aten.index_put %522, %535, %534, %false_309 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %536, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_310 = torch.constant.int 32 + %int2_311 = torch.constant.int 2 + %int8_312 = torch.constant.int 8 + %int32_313 = torch.constant.int 32 + %int128_314 = torch.constant.int 128 + %537 = torch.prim.ListConstruct %297, %int32_310, %int2_311, %int8_312, %int32_313, %int128_314 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %538 = 
torch.aten.view %536, %537 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %538, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_315 = torch.constant.int 2097152 + %539 = torch.prim.ListConstruct %297, %int2097152_315 : (!torch.int, !torch.int) -> !torch.list + %540 = torch.aten.view %538, %539 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %540, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_316 = torch.constant.int -2 + %541 = torch.aten.unsqueeze %493, %int-2_316 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %541, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_317 = torch.constant.int 4 + %int8_318 = torch.constant.int 8 + %int4_319 = torch.constant.int 4 + %int128_320 = torch.constant.int 128 + %542 = torch.prim.ListConstruct %int4_317, %298, %int8_318, %int4_319, %int128_320 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_321 = torch.constant.bool false + %543 = torch.aten.expand %541, %542, %false_321 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %543, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_322 = torch.constant.int 0 + %544 = torch.aten.clone %543, %int0_322 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %544, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_323 = torch.constant.int 4 + %int32_324 = torch.constant.int 32 + %int128_325 = torch.constant.int 128 + %545 = torch.prim.ListConstruct %int4_323, %298, %int32_324, %int128_325 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %546 = torch.aten._unsafe_view %544, %545 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %546, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_326 = torch.constant.int -2 + %547 = torch.aten.unsqueeze %367, %int-2_326 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %547, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_327 = torch.constant.int 4 + %int8_328 = torch.constant.int 8 + %int4_329 = torch.constant.int 4 + %int128_330 = torch.constant.int 128 + %548 = torch.prim.ListConstruct %int4_327, %298, %int8_328, %int4_329, %int128_330 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_331 = torch.constant.bool false + %549 = torch.aten.expand %547, %548, %false_331 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %549, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_332 = torch.constant.int 0 + %550 = torch.aten.clone %549, %int0_332 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %550, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_333 = 
torch.constant.int 4 + %int32_334 = torch.constant.int 32 + %int128_335 = torch.constant.int 128 + %551 = torch.prim.ListConstruct %int4_333, %298, %int32_334, %int128_335 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %552 = torch.aten._unsafe_view %550, %551 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %552, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_336 = torch.constant.int 1 + %int2_337 = torch.constant.int 2 + %553 = torch.aten.transpose.int %430, %int1_336, %int2_337 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %553, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_338 = torch.constant.int 1 - %574 = torch.aten.slice.Tensor %571, %int0_336, %int0_337, %573, %int1_338 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %574, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_339 = torch.constant.int 1 - %int0_340 = torch.constant.int 0 - %int9223372036854775807_341 = torch.constant.int 9223372036854775807 - %int1_342 = torch.constant.int 1 - %575 = torch.aten.slice.Tensor %574, %int1_339, %int0_340, %int9223372036854775807_341, %int1_342 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %575, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_343 = torch.constant.int 1 - %int0_344 = torch.constant.int 0 - %int9223372036854775807_345 = torch.constant.int 9223372036854775807 - %int1_346 = torch.constant.int 1 - %576 = torch.aten.slice.Tensor %575, %int1_343, %int0_344, %int9223372036854775807_345, %int1_346 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %576, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_347 = torch.constant.int 0 - %577 = torch.aten.unsqueeze %576, %int0_347 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %577, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_348 = torch.constant.int 1 - %int0_349 = torch.constant.int 0 - %int9223372036854775807_350 = torch.constant.int 9223372036854775807 - %int1_351 = torch.constant.int 1 - %578 = torch.aten.slice.Tensor %577, %int1_348, %int0_349, %int9223372036854775807_350, %int1_351 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %578, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_352 = torch.constant.int 2 - %int0_353 = torch.constant.int 0 - %int9223372036854775807_354 = torch.constant.int 9223372036854775807 - %int1_355 = torch.constant.int 1 - %579 = torch.aten.slice.Tensor %578, %int2_352, %int0_353, %int9223372036854775807_354, %int1_355 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %579, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_356 = torch.constant.int 4 - %int1_357 = torch.constant.int 1 - %int1_358 = torch.constant.int 1 - %580 = 
torch.prim.ListConstruct %int4_356, %int1_357, %int1_358 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %581 = torch.aten.repeat %579, %580 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %581, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_359 = torch.constant.int 6 - %582 = torch.prims.convert_element_type %529, %int6_359 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %582, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %583 = torch_c.to_builtin_tensor %582 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %584 = torch_c.to_builtin_tensor %581 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %585 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%583, %584) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %586 = torch_c.from_builtin_tensor %585 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %586, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_360 = torch.constant.int 5 - %587 = torch.prims.convert_element_type %586, %int5_360 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %587, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_361 = torch.constant.int 64 - %588 = torch.aten.mul.Scalar %arg2, %int64_361 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %588, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int2_362 = torch.constant.int 2 - %int1_363 = torch.constant.int 1 - %589 = torch.aten.add.Scalar %588, %int2_362, %int1_363 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %589, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_364 = torch.constant.int 4 - %int32_365 = torch.constant.int 32 - %int8_366 = torch.constant.int 8 - %int128_367 = torch.constant.int 128 - %590 = torch.prim.ListConstruct %int4_364, %398, %int32_365, %int8_366, %int128_367 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %591 = torch.aten.view %587, %590 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %591, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int2_339 = torch.constant.int 2 + %554 = torch.aten.transpose.int %546, %int1_338, %int2_339 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %554, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_340 = torch.constant.int 1 + %int2_341 = torch.constant.int 2 + %555 = torch.aten.transpose.int %552, %int1_340, %int2_341 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %555, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00 = torch.constant.float 0.000000e+00 + %false_342 = torch.constant.bool false + %none_343 = torch.constant.none + %556:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%553, %554, %555, %float0.000000e00, %false_342, %327, %none_343) : 
(!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %556#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_344 = torch.constant.int 1 + %int2_345 = torch.constant.int 2 + %557 = torch.aten.transpose.int %556#0, %int1_344, %int2_345 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %557, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_346 = torch.constant.int 4 + %int4096_347 = torch.constant.int 4096 + %558 = torch.prim.ListConstruct %int4_346, %298, %int4096_347 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %559 = torch.aten.view %557, %558 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %559, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_348 = torch.constant.int -2 + %int-1_349 = torch.constant.int -1 + %560 = torch.aten.transpose.int %6, %int-2_348, %int-1_349 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_350 = torch.constant.int 5 + %561 = torch.prims.convert_element_type %560, %int5_350 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_351 = torch.constant.int 4096 + %562 = torch.prim.ListConstruct %342, %int4096_351 : (!torch.int, !torch.int) -> !torch.list + %563 = torch.aten.view %559, %562 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %563, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %564 = torch.aten.mm %563, %561 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %564, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_352 = torch.constant.int 4 + %int4096_353 = torch.constant.int 4096 + %565 = torch.prim.ListConstruct %int4_352, %298, %int4096_353 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %566 = torch.aten.view %564, %565 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %566, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_354 = torch.constant.int 1 + %567 = torch.aten.add.Tensor %329, %566, %int1_354 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %567, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_355 = torch.constant.int 6 + %568 = torch.prims.convert_element_type %567, %int6_355 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %568, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_356 = torch.constant.int 2 + %569 = torch.aten.pow.Tensor_Scalar %568, %int2_356 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %569, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_357 = torch.constant.int -1 + %570 = torch.prim.ListConstruct %int-1_357 : 
(!torch.int) -> !torch.list + %true_358 = torch.constant.bool true + %none_359 = torch.constant.none + %571 = torch.aten.mean.dim %569, %570, %true_358, %none_359 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %571, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_360 = torch.constant.float 9.9999997473787516E-6 + %int1_361 = torch.constant.int 1 + %572 = torch.aten.add.Scalar %571, %float9.999990e-06_360, %int1_361 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %572, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %573 = torch.aten.rsqrt %572 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %573, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %574 = torch.aten.mul.Tensor %568, %573 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %574, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_362 = torch.constant.int 5 + %575 = torch.prims.convert_element_type %574, %int5_362 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %575, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %576 = torch.aten.mul.Tensor %7, %575 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %576, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_363 = torch.constant.int 5 + %577 = torch.prims.convert_element_type %576, %int5_363 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %577, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_364 = torch.constant.int -2 + %int-1_365 = torch.constant.int -1 + %578 = torch.aten.transpose.int %8, %int-2_364, %int-1_365 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_366 = torch.constant.int 5 + %579 = torch.prims.convert_element_type %578, %int5_366 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_367 = torch.constant.int 4096 + %580 = torch.prim.ListConstruct %342, %int4096_367 : (!torch.int, !torch.int) -> !torch.list + %581 = torch.aten.view %577, %580 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %581, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %582 = torch.aten.mm %581, %579 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %582, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> %int4_368 = torch.constant.int 4 - %592 = torch.aten.mul.int %int4_368, %398 : !torch.int, !torch.int -> !torch.int - %int32_369 = torch.constant.int 32 - %int8_370 = torch.constant.int 8 - %int128_371 = torch.constant.int 128 - %593 = torch.prim.ListConstruct %592, %int32_369, %int8_370, %int128_371 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %594 = torch.aten.view %591, %593 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - 
torch.bind_symbolic_shape %594, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_372 = torch.constant.int 4 - %595 = torch.aten.mul.int %int4_372, %398 : !torch.int, !torch.int -> !torch.int - %596 = torch.prim.ListConstruct %595 : (!torch.int) -> !torch.list - %597 = torch.aten.view %589, %596 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %597, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_373 = torch.constant.int 32 - %int2_374 = torch.constant.int 2 - %int32_375 = torch.constant.int 32 - %int8_376 = torch.constant.int 8 - %int128_377 = torch.constant.int 128 - %598 = torch.prim.ListConstruct %389, %int32_373, %int2_374, %int32_375, %int8_376, %int128_377 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %599 = torch.aten.view %431, %598 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %599, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_378 = torch.constant.int 32 - %600 = torch.aten.mul.int %389, %int32_378 : !torch.int, !torch.int -> !torch.int - %int2_379 = torch.constant.int 2 - %601 = torch.aten.mul.int %600, %int2_379 : !torch.int, !torch.int -> !torch.int - %int32_380 = torch.constant.int 32 - %int8_381 = torch.constant.int 8 - %int128_382 = torch.constant.int 128 - %602 = torch.prim.ListConstruct %601, %int32_380, %int8_381, %int128_382 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %603 = torch.aten.view %599, %602 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %603, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %604 = torch.prim.ListConstruct %597 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_383 = torch.constant.bool false - %605 = torch.aten.index_put %603, %604, %594, %false_383 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %605, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_384 = torch.constant.int 32 - %int2_385 = torch.constant.int 2 - %int32_386 = torch.constant.int 32 - %int8_387 = torch.constant.int 8 - %int128_388 = torch.constant.int 128 - %606 = torch.prim.ListConstruct %389, %int32_384, %int2_385, %int32_386, %int8_387, %int128_388 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %607 = torch.aten.view %605, %606 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %607, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_389 = torch.constant.int 2097152 - %608 = torch.prim.ListConstruct %389, %int2097152_389 : (!torch.int, !torch.int) -> !torch.list - %609 = torch.aten.view %607, %608 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %609, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_390 = torch.constant.int 32 - %int2_391 = torch.constant.int 2 - %int32_392 = torch.constant.int 32 - %int8_393 = torch.constant.int 8 - %int128_394 = torch.constant.int 128 - %610 = torch.prim.ListConstruct %389, %int32_390, 
%int2_391, %int32_392, %int8_393, %int128_394 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %611 = torch.aten.view %609, %610 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %611, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_395 = torch.constant.int 32 - %int8_396 = torch.constant.int 8 - %int128_397 = torch.constant.int 128 - %612 = torch.prim.ListConstruct %601, %int32_395, %int8_396, %int128_397 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %613 = torch.aten.view %611, %612 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %613, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_398 = torch.constant.int 4 - %int32_399 = torch.constant.int 32 - %int8_400 = torch.constant.int 8 - %int128_401 = torch.constant.int 128 - %614 = torch.prim.ListConstruct %int4_398, %398, %int32_399, %int8_400, %int128_401 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %615 = torch.aten.view %531, %614 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %615, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_402 = torch.constant.int 4 - %616 = torch.aten.mul.int %int4_402, %398 : !torch.int, !torch.int -> !torch.int - %int32_403 = torch.constant.int 32 - %int8_404 = torch.constant.int 8 - %int128_405 = torch.constant.int 128 - %617 = torch.prim.ListConstruct %616, %int32_403, %int8_404, %int128_405 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %618 = torch.aten.view %615, %617 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %618, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_406 = torch.constant.int 1 - %int1_407 = torch.constant.int 1 - %619 = torch.aten.add.Scalar %589, %int1_406, %int1_407 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %619, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_408 = torch.constant.int 4 - %620 = torch.aten.mul.int %int4_408, %398 : !torch.int, !torch.int -> !torch.int - %621 = torch.prim.ListConstruct %620 : (!torch.int) -> !torch.list - %622 = torch.aten.view %619, %621 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %622, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %623 = torch.prim.ListConstruct %622 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_409 = torch.constant.bool false - %624 = torch.aten.index_put %613, %623, %618, %false_409 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %624, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int14336 = torch.constant.int 14336 + %583 = torch.prim.ListConstruct %int4_368, %298, %int14336 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %584 = torch.aten.view %582, %583 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %584, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : 
!torch.vtensor<[4,?,14336],f16> + %585 = torch.aten.silu %584 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %585, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_369 = torch.constant.int -2 + %int-1_370 = torch.constant.int -1 + %586 = torch.aten.transpose.int %9, %int-2_369, %int-1_370 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_371 = torch.constant.int 5 + %587 = torch.prims.convert_element_type %586, %int5_371 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_372 = torch.constant.int 4096 + %588 = torch.prim.ListConstruct %342, %int4096_372 : (!torch.int, !torch.int) -> !torch.list + %589 = torch.aten.view %577, %588 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %589, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %590 = torch.aten.mm %589, %587 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %590, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_373 = torch.constant.int 4 + %int14336_374 = torch.constant.int 14336 + %591 = torch.prim.ListConstruct %int4_373, %298, %int14336_374 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %592 = torch.aten.view %590, %591 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %592, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %593 = torch.aten.mul.Tensor %585, %592 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %593, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_375 = torch.constant.int -2 + %int-1_376 = torch.constant.int -1 + %594 = torch.aten.transpose.int %10, %int-2_375, %int-1_376 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_377 = torch.constant.int 5 + %595 = torch.prims.convert_element_type %594, %int5_377 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_378 = torch.constant.int 14336 + %596 = torch.prim.ListConstruct %342, %int14336_378 : (!torch.int, !torch.int) -> !torch.list + %597 = torch.aten.view %593, %596 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %597, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %598 = torch.aten.mm %597, %595 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %598, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_379 = torch.constant.int 4 + %int4096_380 = torch.constant.int 4096 + %599 = torch.prim.ListConstruct %int4_379, %298, %int4096_380 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %600 = torch.aten.view %598, %599 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %600, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_381 = torch.constant.int 1 + %601 = torch.aten.add.Tensor %567, %600, %int1_381 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, 
!torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %601, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_382 = torch.constant.int 6 + %602 = torch.prims.convert_element_type %601, %int6_382 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %602, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_383 = torch.constant.int 2 + %603 = torch.aten.pow.Tensor_Scalar %602, %int2_383 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %603, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_384 = torch.constant.int -1 + %604 = torch.prim.ListConstruct %int-1_384 : (!torch.int) -> !torch.list + %true_385 = torch.constant.bool true + %none_386 = torch.constant.none + %605 = torch.aten.mean.dim %603, %604, %true_385, %none_386 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %605, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_387 = torch.constant.float 9.9999997473787516E-6 + %int1_388 = torch.constant.int 1 + %606 = torch.aten.add.Scalar %605, %float9.999990e-06_387, %int1_388 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %606, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %607 = torch.aten.rsqrt %606 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %607, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %608 = torch.aten.mul.Tensor %602, %607 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %608, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_389 = torch.constant.int 5 + %609 = torch.prims.convert_element_type %608, %int5_389 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %609, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %610 = torch.aten.mul.Tensor %11, %609 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %610, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_390 = torch.constant.int 5 + %611 = torch.prims.convert_element_type %610, %int5_390 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %611, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_391 = torch.constant.int -2 + %int-1_392 = torch.constant.int -1 + %612 = torch.aten.transpose.int %12, %int-2_391, %int-1_392 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_393 = torch.constant.int 5 + %613 = torch.prims.convert_element_type %612, %int5_393 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_394 = torch.constant.int 4096 + %614 = torch.prim.ListConstruct %342, %int4096_394 : (!torch.int, !torch.int) -> !torch.list + %615 = torch.aten.view %611, %614 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %615, [%294], 
affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %616 = torch.aten.mm %615, %613 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %616, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_395 = torch.constant.int 4 + %int4096_396 = torch.constant.int 4096 + %617 = torch.prim.ListConstruct %int4_395, %298, %int4096_396 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %618 = torch.aten.view %616, %617 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %618, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_397 = torch.constant.int -2 + %int-1_398 = torch.constant.int -1 + %619 = torch.aten.transpose.int %13, %int-2_397, %int-1_398 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_399 = torch.constant.int 5 + %620 = torch.prims.convert_element_type %619, %int5_399 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_400 = torch.constant.int 4096 + %621 = torch.prim.ListConstruct %342, %int4096_400 : (!torch.int, !torch.int) -> !torch.list + %622 = torch.aten.view %611, %621 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %622, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %623 = torch.aten.mm %622, %620 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %623, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_401 = torch.constant.int 4 + %int1024_402 = torch.constant.int 1024 + %624 = torch.prim.ListConstruct %int4_401, %298, %int1024_402 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %625 = torch.aten.view %623, %624 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %625, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_403 = torch.constant.int -2 + %int-1_404 = torch.constant.int -1 + %626 = torch.aten.transpose.int %14, %int-2_403, %int-1_404 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_405 = torch.constant.int 5 + %627 = torch.prims.convert_element_type %626, %int5_405 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_406 = torch.constant.int 4096 + %628 = torch.prim.ListConstruct %342, %int4096_406 : (!torch.int, !torch.int) -> !torch.list + %629 = torch.aten.view %611, %628 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %629, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %630 = torch.aten.mm %629, %627 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %630, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_407 = torch.constant.int 4 + %int1024_408 = torch.constant.int 1024 + %631 = torch.prim.ListConstruct %int4_407, %298, %int1024_408 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %632 = torch.aten.view %630, %631 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %632, [%294], affine_map<()[s0] -> (4, s0 * 
32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_409 = torch.constant.int 4 %int32_410 = torch.constant.int 32 - %int2_411 = torch.constant.int 2 - %int32_412 = torch.constant.int 32 + %int128_411 = torch.constant.int 128 + %633 = torch.prim.ListConstruct %int4_409, %298, %int32_410, %int128_411 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %634 = torch.aten.view %618, %633 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %634, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_412 = torch.constant.int 4 %int8_413 = torch.constant.int 8 %int128_414 = torch.constant.int 128 - %625 = torch.prim.ListConstruct %389, %int32_410, %int2_411, %int32_412, %int8_413, %int128_414 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %626 = torch.aten.view %624, %625 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %626, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_415 = torch.constant.int 2097152 - %627 = torch.prim.ListConstruct %389, %int2097152_415 : (!torch.int, !torch.int) -> !torch.list - %628 = torch.aten.view %626, %627 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %628, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_416 = torch.constant.int -2 - %629 = torch.aten.unsqueeze %587, %int-2_416 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %629, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_417 = torch.constant.int 4 - %int8_418 = torch.constant.int 8 - %int4_419 = torch.constant.int 4 - %int128_420 = torch.constant.int 128 - %630 = torch.prim.ListConstruct %int4_417, %572, %int8_418, %int4_419, %int128_420 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_421 = torch.constant.bool false - %631 = torch.aten.expand %629, %630, %false_421 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %631, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_422 = torch.constant.int 0 - %632 = torch.aten.clone %631, %int0_422 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %632, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_423 = torch.constant.int 4 - %int32_424 = torch.constant.int 32 - %int128_425 = torch.constant.int 128 - %633 = torch.prim.ListConstruct %int4_423, %572, %int32_424, %int128_425 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %634 = torch.aten._unsafe_view %632, %633 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %634, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_426 = torch.constant.int -2 - %635 = torch.aten.unsqueeze %531, %int-2_426 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %635, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_427 = 
torch.constant.int 1 - %636 = torch.aten.size.int %525, %int1_427 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_428 = torch.constant.int 4 - %int8_429 = torch.constant.int 8 - %int4_430 = torch.constant.int 4 + %635 = torch.prim.ListConstruct %int4_412, %298, %int8_413, %int128_414 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %636 = torch.aten.view %625, %635 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %636, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_415 = torch.constant.int 4 + %int8_416 = torch.constant.int 8 + %int128_417 = torch.constant.int 128 + %637 = torch.prim.ListConstruct %int4_415, %298, %int8_416, %int128_417 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %638 = torch.aten.view %632, %637 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %638, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_418 = torch.constant.int 131072 + %none_419 = torch.constant.none + %none_420 = torch.constant.none + %cpu_421 = torch.constant.device "cpu" + %false_422 = torch.constant.bool false + %639 = torch.aten.arange %int131072_418, %none_419, %none_420, %cpu_421, %false_422 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_423 = torch.constant.int 0 + %int128_424 = torch.constant.int 128 + %int2_425 = torch.constant.int 2 + %int4_426 = torch.constant.int 4 + %none_427 = torch.constant.none + %cpu_428 = torch.constant.device "cpu" + %false_429 = torch.constant.bool false + %640 = torch.aten.arange.start_step %int0_423, %int128_424, %int2_425, %int4_426, %none_427, %cpu_428, %false_429 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_430 = torch.constant.int 6 + %641 = torch.prims.convert_element_type %640, %int6_430 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> %int128_431 = torch.constant.int 128 - %637 = torch.prim.ListConstruct %int4_428, %636, %int8_429, %int4_430, %int128_431 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_432 = torch.constant.bool false - %638 = torch.aten.expand %635, %637, %false_432 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %638, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_433 = torch.constant.int 0 - %639 = torch.aten.clone %638, %int0_433 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %639, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_434 = torch.constant.int 4 - %int32_435 = torch.constant.int 32 - %int128_436 = torch.constant.int 128 - %640 = torch.prim.ListConstruct %int4_434, %636, %int32_435, %int128_436 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %641 = torch.aten._unsafe_view %639, %640 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %641, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_437 = torch.constant.int 1 - %int2_438 = torch.constant.int 2 - %642 = torch.aten.transpose.int %559, %int1_437, 
%int2_438 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %642, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %642 = torch.aten.div.Scalar %641, %int128_431 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_432 = torch.constant.float 5.000000e+05 + %643 = torch.aten.pow.Scalar %float5.000000e05_432, %642 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %644 = torch.aten.reciprocal %643 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_433 = torch.constant.float 1.000000e+00 + %645 = torch.aten.mul.Scalar %644, %float1.000000e00_433 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %646 = torch.aten.reciprocal %645 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_434 = torch.constant.float 6.2831853071795862 + %647 = torch.aten.mul.Scalar %646, %float6.283190e00_434 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_435 = torch.constant.float 8.192000e+03 + %648 = torch.aten.gt.Scalar %647, %float8.192000e03_435 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_436 = torch.constant.int 8 + %649 = torch.aten.div.Scalar %645, %int8_436 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %650 = torch.aten.where.self %648, %649, %645 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %651 = torch.aten.reciprocal %647 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_437 = torch.constant.int 8192 + %652 = torch.aten.mul.Scalar %651, %int8192_437 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_438 = torch.constant.int 1 %int1_439 = torch.constant.int 1 - %int2_440 = torch.constant.int 2 - %643 = torch.aten.transpose.int %634, %int1_439, %int2_440 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %643, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %653 = torch.aten.sub.Scalar %652, %int1_438, %int1_439 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_440 = torch.constant.int 3 + %654 = torch.aten.div.Scalar %653, %int3_440 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_441 = torch.constant.int 1 - %int2_442 = torch.constant.int 2 - %644 = torch.aten.transpose.int %641, %int1_441, %int2_442 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %644, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_443 = torch.constant.float 0.000000e+00 - %true_444 = torch.constant.bool true - %none_445 = torch.constant.none - %none_446 = torch.constant.none - %645:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%642, %643, %644, %float0.000000e00_443, %true_444, %none_445, %none_446) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %645#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_447 = torch.constant.int 1 - 
%int2_448 = torch.constant.int 2 - %646 = torch.aten.transpose.int %645#0, %int1_447, %int2_448 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %646, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_449 = torch.constant.int 4 - %int4096_450 = torch.constant.int 4096 - %647 = torch.prim.ListConstruct %int4_449, %544, %int4096_450 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %648 = torch.aten.view %646, %647 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %648, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_451 = torch.constant.int -2 - %int-1_452 = torch.constant.int -1 - %649 = torch.aten.transpose.int %14, %int-2_451, %int-1_452 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_453 = torch.constant.int 4 - %650 = torch.aten.mul.int %int4_453, %544 : !torch.int, !torch.int -> !torch.int - %int4096_454 = torch.constant.int 4096 - %651 = torch.prim.ListConstruct %650, %int4096_454 : (!torch.int, !torch.int) -> !torch.list - %652 = torch.aten.view %648, %651 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %652, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %653 = torch.aten.mm %652, %649 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %653, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_455 = torch.constant.int 4 - %int4096_456 = torch.constant.int 4096 - %654 = torch.prim.ListConstruct %int4_455, %544, %int4096_456 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %655 = torch.aten.view %653, %654 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %655, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_442 = torch.constant.int 1 + %655 = torch.aten.rsub.Scalar %654, %int1_441, %int1_442 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %656 = torch.aten.mul.Tensor %655, %650 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_443 = torch.constant.int 8 + %657 = torch.aten.div.Scalar %656, %int8_443 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %658 = torch.aten.mul.Tensor %654, %650 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_444 = torch.constant.int 1 + %659 = torch.aten.add.Tensor %657, %658, %int1_444 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_445 = torch.constant.float 2.048000e+03 + %660 = torch.aten.lt.Scalar %647, %float2.048000e03_445 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %661 = torch.aten.bitwise_not %660 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_446 = torch.constant.float 8.192000e+03 + %662 = torch.aten.gt.Scalar %647, %float8.192000e03_446 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %663 = torch.aten.bitwise_not %662 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %664 = torch.aten.mul.Tensor %661, %663 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %665 = 
torch.aten.where.self %664, %659, %650 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %666 = torch.prim.ListConstruct %665, %665 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_447 = torch.constant.int -1 + %667 = torch.aten.cat %666, %int-1_447 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_448 = torch.constant.int 6 + %668 = torch.prims.convert_element_type %667, %int6_448 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_449 = torch.constant.int 1 + %669 = torch.aten.unsqueeze %639, %int1_449 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_450 = torch.constant.int 6 + %670 = torch.prims.convert_element_type %669, %int6_450 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_451 = torch.constant.int 0 + %671 = torch.aten.unsqueeze %668, %int0_451 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_452 = torch.constant.int 6 + %672 = torch.prims.convert_element_type %671, %int6_452 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %673 = torch.aten.mul.Tensor %670, %672 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %674 = torch.aten.cos %673 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_453 = torch.constant.int 5 + %675 = torch.prims.convert_element_type %674, %int5_453 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %676 = torch.aten.sin %673 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_454 = torch.constant.int 5 + %677 = torch.prims.convert_element_type %676, %int5_454 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_455 = torch.constant.int 0 + %int0_456 = torch.constant.int 0 %int1_457 = torch.constant.int 1 - %656 = torch.aten.add.Tensor %494, %655, %int1_457 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %656, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_458 = torch.constant.int 6 - %657 = torch.prims.convert_element_type %656, %int6_458 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %657, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_459 = torch.constant.int 2 - %658 = torch.aten.pow.Tensor_Scalar %657, %int2_459 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %658, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_460 = torch.constant.int -1 - %659 = torch.prim.ListConstruct %int-1_460 : (!torch.int) -> !torch.list - %true_461 = torch.constant.bool true - %none_462 = torch.constant.none - %660 = torch.aten.mean.dim %658, %659, %true_461, %none_462 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %660, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_463 = torch.constant.float 9.9999997473787516E-6 + %678 = torch.aten.slice.Tensor %675, %int0_455, %int0_456, %298, %int1_457 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> 
!torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %678, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_458 = torch.constant.int 1 + %int0_459 = torch.constant.int 0 + %int9223372036854775807_460 = torch.constant.int 9223372036854775807 + %int1_461 = torch.constant.int 1 + %679 = torch.aten.slice.Tensor %678, %int1_458, %int0_459, %int9223372036854775807_460, %int1_461 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %679, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_462 = torch.constant.int 0 + %int0_463 = torch.constant.int 0 %int1_464 = torch.constant.int 1 - %661 = torch.aten.add.Scalar %660, %float9.999990e-06_463, %int1_464 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %661, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %662 = torch.aten.rsqrt %661 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %662, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %663 = torch.aten.mul.Tensor %657, %662 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %663, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_465 = torch.constant.int 5 - %664 = torch.prims.convert_element_type %663, %int5_465 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %664, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %665 = torch.aten.mul.Tensor %15, %664 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %665, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_466 = torch.constant.int 5 - %666 = torch.prims.convert_element_type %665, %int5_466 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %666, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_467 = torch.constant.int -2 - %int-1_468 = torch.constant.int -1 - %667 = torch.aten.transpose.int %16, %int-2_467, %int-1_468 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_469 = torch.constant.int 4 - %668 = torch.aten.mul.int %int4_469, %306 : !torch.int, !torch.int -> !torch.int - %int4096_470 = torch.constant.int 4096 - %669 = torch.prim.ListConstruct %668, %int4096_470 : (!torch.int, !torch.int) -> !torch.list - %670 = torch.aten.view %666, %669 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %670, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %671 = torch.aten.mm %670, %667 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %671, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_471 = torch.constant.int 4 - %int14336_472 = torch.constant.int 14336 - %672 = torch.prim.ListConstruct %int4_471, %306, %int14336_472 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %673 = torch.aten.view %671, %672 : !torch.vtensor<[?,14336],f16>, !torch.list -> 
!torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %673, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %674 = torch.aten.silu %673 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %674, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_473 = torch.constant.int -2 - %int-1_474 = torch.constant.int -1 - %675 = torch.aten.transpose.int %17, %int-2_473, %int-1_474 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_475 = torch.constant.int 4 - %676 = torch.aten.mul.int %int4_475, %306 : !torch.int, !torch.int -> !torch.int - %int4096_476 = torch.constant.int 4096 - %677 = torch.prim.ListConstruct %676, %int4096_476 : (!torch.int, !torch.int) -> !torch.list - %678 = torch.aten.view %666, %677 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %678, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %679 = torch.aten.mm %678, %675 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %679, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_477 = torch.constant.int 4 - %int14336_478 = torch.constant.int 14336 - %680 = torch.prim.ListConstruct %int4_477, %306, %int14336_478 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %681 = torch.aten.view %679, %680 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %681, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %682 = torch.aten.mul.Tensor %674, %681 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %682, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_479 = torch.constant.int -2 - %int-1_480 = torch.constant.int -1 - %683 = torch.aten.transpose.int %18, %int-2_479, %int-1_480 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %680 = torch.aten.slice.Tensor %677, %int0_462, %int0_463, %298, %int1_464 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %680, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_465 = torch.constant.int 1 + %int0_466 = torch.constant.int 0 + %int9223372036854775807_467 = torch.constant.int 9223372036854775807 + %int1_468 = torch.constant.int 1 + %681 = torch.aten.slice.Tensor %680, %int1_465, %int0_466, %int9223372036854775807_467, %int1_468 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %681, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_469 = torch.constant.int 0 + %682 = torch.aten.unsqueeze %679, %int0_469 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %682, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_470 = torch.constant.int 1 + %int0_471 = torch.constant.int 0 + %int9223372036854775807_472 = torch.constant.int 9223372036854775807 + %int1_473 = torch.constant.int 1 + %683 = torch.aten.slice.Tensor %682, %int1_470, %int0_471, 
%int9223372036854775807_472, %int1_473 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %683, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_474 = torch.constant.int 2 + %684 = torch.aten.unsqueeze %683, %int2_474 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %684, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_475 = torch.constant.int 3 + %int0_476 = torch.constant.int 0 + %int9223372036854775807_477 = torch.constant.int 9223372036854775807 + %int1_478 = torch.constant.int 1 + %685 = torch.aten.slice.Tensor %684, %int3_475, %int0_476, %int9223372036854775807_477, %int1_478 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %685, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_479 = torch.constant.int 4 + %int1_480 = torch.constant.int 1 %int1_481 = torch.constant.int 1 - %684 = torch.aten.size.int %673, %int1_481 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_482 = torch.constant.int 4 - %685 = torch.aten.mul.int %int4_482, %684 : !torch.int, !torch.int -> !torch.int - %int14336_483 = torch.constant.int 14336 - %686 = torch.prim.ListConstruct %685, %int14336_483 : (!torch.int, !torch.int) -> !torch.list - %687 = torch.aten.view %682, %686 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %687, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %688 = torch.aten.mm %687, %683 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %688, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_484 = torch.constant.int 4 - %int4096_485 = torch.constant.int 4096 - %689 = torch.prim.ListConstruct %int4_484, %684, %int4096_485 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %690 = torch.aten.view %688, %689 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %690, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_486 = torch.constant.int 1 - %691 = torch.aten.add.Tensor %656, %690, %int1_486 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %691, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_487 = torch.constant.int 6 - %692 = torch.prims.convert_element_type %691, %int6_487 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %692, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int1_482 = torch.constant.int 1 + %686 = torch.prim.ListConstruct %int4_479, %int1_480, %int1_481, %int1_482 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %687 = torch.aten.repeat %685, %686 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %687, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_483 = torch.constant.int 0 + %688 = torch.aten.unsqueeze %681, %int0_483 : !torch.vtensor<[?,128],f16>, 
!torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %688, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_484 = torch.constant.int 1 + %int0_485 = torch.constant.int 0 + %int9223372036854775807_486 = torch.constant.int 9223372036854775807 + %int1_487 = torch.constant.int 1 + %689 = torch.aten.slice.Tensor %688, %int1_484, %int0_485, %int9223372036854775807_486, %int1_487 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %689, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int2_488 = torch.constant.int 2 - %693 = torch.aten.pow.Tensor_Scalar %692, %int2_488 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %693, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_489 = torch.constant.int -1 - %694 = torch.prim.ListConstruct %int-1_489 : (!torch.int) -> !torch.list - %true_490 = torch.constant.bool true - %none_491 = torch.constant.none - %695 = torch.aten.mean.dim %693, %694, %true_490, %none_491 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %695, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_492 = torch.constant.float 9.9999997473787516E-6 - %int1_493 = torch.constant.int 1 - %696 = torch.aten.add.Scalar %695, %float9.999990e-06_492, %int1_493 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %696, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %697 = torch.aten.rsqrt %696 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %697, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %698 = torch.aten.mul.Tensor %692, %697 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %698, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_494 = torch.constant.int 5 - %699 = torch.prims.convert_element_type %698, %int5_494 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %699, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %700 = torch.aten.mul.Tensor %19, %699 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %700, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_495 = torch.constant.int 5 - %701 = torch.prims.convert_element_type %700, %int5_495 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %701, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_496 = torch.constant.int -2 - %int-1_497 = torch.constant.int -1 - %702 = torch.aten.transpose.int %20, %int-2_496, %int-1_497 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_498 = torch.constant.int 4 - %703 = torch.aten.mul.int %int4_498, %306 : !torch.int, !torch.int -> !torch.int - %int4096_499 = torch.constant.int 4096 - %704 = torch.prim.ListConstruct %703, %int4096_499 : (!torch.int, !torch.int) -> !torch.list - %705 = 
torch.aten.view %701, %704 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %705, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %706 = torch.aten.mm %705, %702 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %706, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_500 = torch.constant.int 4 - %int4096_501 = torch.constant.int 4096 - %707 = torch.prim.ListConstruct %int4_500, %306, %int4096_501 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %708 = torch.aten.view %706, %707 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %708, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_502 = torch.constant.int -2 - %int-1_503 = torch.constant.int -1 - %709 = torch.aten.transpose.int %21, %int-2_502, %int-1_503 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_504 = torch.constant.int 4 - %710 = torch.aten.mul.int %int4_504, %306 : !torch.int, !torch.int -> !torch.int - %int4096_505 = torch.constant.int 4096 - %711 = torch.prim.ListConstruct %710, %int4096_505 : (!torch.int, !torch.int) -> !torch.list - %712 = torch.aten.view %701, %711 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %712, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %713 = torch.aten.mm %712, %709 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %713, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_506 = torch.constant.int 4 - %int1024_507 = torch.constant.int 1024 - %714 = torch.prim.ListConstruct %int4_506, %306, %int1024_507 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %715 = torch.aten.view %713, %714 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %715, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_508 = torch.constant.int -2 - %int-1_509 = torch.constant.int -1 - %716 = torch.aten.transpose.int %22, %int-2_508, %int-1_509 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_510 = torch.constant.int 4 - %717 = torch.aten.mul.int %int4_510, %306 : !torch.int, !torch.int -> !torch.int - %int4096_511 = torch.constant.int 4096 - %718 = torch.prim.ListConstruct %717, %int4096_511 : (!torch.int, !torch.int) -> !torch.list - %719 = torch.aten.view %701, %718 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %719, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %720 = torch.aten.mm %719, %716 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %720, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_512 = torch.constant.int 4 - %int1024_513 = torch.constant.int 1024 - %721 = torch.prim.ListConstruct %int4_512, %306, %int1024_513 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %722 = torch.aten.view %720, %721 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %722, 
[%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_514 = torch.constant.int 4 - %int32_515 = torch.constant.int 32 - %int128_516 = torch.constant.int 128 - %723 = torch.prim.ListConstruct %int4_514, %306, %int32_515, %int128_516 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %724 = torch.aten.view %708, %723 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %724, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_517 = torch.constant.int 4 - %int8_518 = torch.constant.int 8 - %int128_519 = torch.constant.int 128 - %725 = torch.prim.ListConstruct %int4_517, %306, %int8_518, %int128_519 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %726 = torch.aten.view %715, %725 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %726, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_520 = torch.constant.int 4 - %int8_521 = torch.constant.int 8 - %int128_522 = torch.constant.int 128 - %727 = torch.prim.ListConstruct %int4_520, %306, %int8_521, %int128_522 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %728 = torch.aten.view %722, %727 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %728, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_523 = torch.constant.int 131072 - %none_524 = torch.constant.none - %none_525 = torch.constant.none - %cpu_526 = torch.constant.device "cpu" - %false_527 = torch.constant.bool false - %729 = torch.aten.arange %int131072_523, %none_524, %none_525, %cpu_526, %false_527 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_528 = torch.constant.int 0 - %int128_529 = torch.constant.int 128 - %none_530 = torch.constant.none - %none_531 = torch.constant.none - %cpu_532 = torch.constant.device "cpu" - %false_533 = torch.constant.bool false - %730 = torch.aten.arange.start %int0_528, %int128_529, %none_530, %none_531, %cpu_532, %false_533 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_534 = torch.constant.int 2 - %731 = torch.aten.floor_divide.Scalar %730, %int2_534 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_535 = torch.constant.int 6 - %732 = torch.prims.convert_element_type %731, %int6_535 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_536 = torch.constant.int 128 - %733 = torch.aten.div.Scalar %732, %int128_536 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_537 = torch.constant.float 2.000000e+00 - %734 = torch.aten.mul.Scalar %733, %float2.000000e00_537 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_538 = torch.constant.float 5.000000e+05 - %735 = torch.aten.pow.Scalar %float5.000000e05_538, %734 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %736 = torch.aten.reciprocal %735 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_539 = torch.constant.float 1.000000e+00 - %737 = torch.aten.mul.Scalar %736, %float1.000000e00_539 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_540 = torch.constant.int 1 - %738 = 
torch.aten.unsqueeze %729, %int1_540 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_541 = torch.constant.int 0 - %739 = torch.aten.unsqueeze %737, %int0_541 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %740 = torch.aten.mul.Tensor %738, %739 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_542 = torch.constant.int 1 - %741 = torch.aten.size.int %708, %int1_542 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_543 = torch.constant.int 0 - %742 = torch.aten.add.int %int0_543, %741 : !torch.int, !torch.int -> !torch.int + %690 = torch.aten.unsqueeze %689, %int2_488 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %690, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_489 = torch.constant.int 3 + %int0_490 = torch.constant.int 0 + %int9223372036854775807_491 = torch.constant.int 9223372036854775807 + %int1_492 = torch.constant.int 1 + %691 = torch.aten.slice.Tensor %690, %int3_489, %int0_490, %int9223372036854775807_491, %int1_492 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %691, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_493 = torch.constant.int 4 + %int1_494 = torch.constant.int 1 + %int1_495 = torch.constant.int 1 + %int1_496 = torch.constant.int 1 + %692 = torch.prim.ListConstruct %int4_493, %int1_494, %int1_495, %int1_496 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %693 = torch.aten.repeat %691, %692 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %693, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %694 = torch.aten.mul.Tensor %634, %687 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %694, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_497 = torch.constant.int 3 + %int0_498 = torch.constant.int 0 + %int64_499 = torch.constant.int 64 + %int1_500 = torch.constant.int 1 + %695 = torch.aten.slice.Tensor %634, %int3_497, %int0_498, %int64_499, %int1_500 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %695, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_501 = torch.constant.int 3 + %int64_502 = torch.constant.int 64 + %int9223372036854775807_503 = torch.constant.int 9223372036854775807 + %int1_504 = torch.constant.int 1 + %696 = torch.aten.slice.Tensor %634, %int3_501, %int64_502, %int9223372036854775807_503, %int1_504 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %696, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %697 = torch.aten.neg %696 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %697, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %698 = torch.prim.ListConstruct %697, %695 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_505 = 
torch.constant.int -1 + %699 = torch.aten.cat %698, %int-1_505 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %699, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %700 = torch.aten.mul.Tensor %699, %693 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %700, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_506 = torch.constant.int 1 + %701 = torch.aten.add.Tensor %694, %700, %int1_506 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %701, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_507 = torch.constant.int 131072 + %none_508 = torch.constant.none + %none_509 = torch.constant.none + %cpu_510 = torch.constant.device "cpu" + %false_511 = torch.constant.bool false + %702 = torch.aten.arange %int131072_507, %none_508, %none_509, %cpu_510, %false_511 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_512 = torch.constant.int 0 + %int128_513 = torch.constant.int 128 + %int2_514 = torch.constant.int 2 + %int4_515 = torch.constant.int 4 + %none_516 = torch.constant.none + %cpu_517 = torch.constant.device "cpu" + %false_518 = torch.constant.bool false + %703 = torch.aten.arange.start_step %int0_512, %int128_513, %int2_514, %int4_515, %none_516, %cpu_517, %false_518 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_519 = torch.constant.int 6 + %704 = torch.prims.convert_element_type %703, %int6_519 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_520 = torch.constant.int 128 + %705 = torch.aten.div.Scalar %704, %int128_520 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_521 = torch.constant.float 5.000000e+05 + %706 = torch.aten.pow.Scalar %float5.000000e05_521, %705 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %707 = torch.aten.reciprocal %706 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_522 = torch.constant.float 1.000000e+00 + %708 = torch.aten.mul.Scalar %707, %float1.000000e00_522 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %709 = torch.aten.reciprocal %708 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_523 = torch.constant.float 6.2831853071795862 + %710 = torch.aten.mul.Scalar %709, %float6.283190e00_523 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_524 = torch.constant.float 8.192000e+03 + %711 = torch.aten.gt.Scalar %710, %float8.192000e03_524 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_525 = torch.constant.int 8 + %712 = torch.aten.div.Scalar %708, %int8_525 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %713 = torch.aten.where.self %711, %712, %708 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %714 = torch.aten.reciprocal %710 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_526 = torch.constant.int 8192 + %715 = torch.aten.mul.Scalar %714, %int8192_526 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_527 = torch.constant.int 
1 + %int1_528 = torch.constant.int 1 + %716 = torch.aten.sub.Scalar %715, %int1_527, %int1_528 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_529 = torch.constant.int 3 + %717 = torch.aten.div.Scalar %716, %int3_529 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_530 = torch.constant.int 1 + %int1_531 = torch.constant.int 1 + %718 = torch.aten.rsub.Scalar %717, %int1_530, %int1_531 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %719 = torch.aten.mul.Tensor %718, %713 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_532 = torch.constant.int 8 + %720 = torch.aten.div.Scalar %719, %int8_532 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %721 = torch.aten.mul.Tensor %717, %713 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_533 = torch.constant.int 1 + %722 = torch.aten.add.Tensor %720, %721, %int1_533 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_534 = torch.constant.float 2.048000e+03 + %723 = torch.aten.lt.Scalar %710, %float2.048000e03_534 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %724 = torch.aten.bitwise_not %723 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_535 = torch.constant.float 8.192000e+03 + %725 = torch.aten.gt.Scalar %710, %float8.192000e03_535 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %726 = torch.aten.bitwise_not %725 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %727 = torch.aten.mul.Tensor %724, %726 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %728 = torch.aten.where.self %727, %722, %713 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %729 = torch.prim.ListConstruct %728, %728 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_536 = torch.constant.int -1 + %730 = torch.aten.cat %729, %int-1_536 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_537 = torch.constant.int 6 + %731 = torch.prims.convert_element_type %730, %int6_537 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_538 = torch.constant.int 1 + %732 = torch.aten.unsqueeze %702, %int1_538 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_539 = torch.constant.int 6 + %733 = torch.prims.convert_element_type %732, %int6_539 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_540 = torch.constant.int 0 + %734 = torch.aten.unsqueeze %731, %int0_540 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_541 = torch.constant.int 6 + %735 = torch.prims.convert_element_type %734, %int6_541 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %736 = torch.aten.mul.Tensor %733, %735 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %737 = torch.aten.cos %736 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_542 = torch.constant.int 5 + %738 = torch.prims.convert_element_type %737, %int5_542 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %739 = torch.aten.sin %736 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_543 = torch.constant.int 5 + %740 = 
torch.prims.convert_element_type %739, %int5_543 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> %int0_544 = torch.constant.int 0 %int0_545 = torch.constant.int 0 %int1_546 = torch.constant.int 1 - %743 = torch.aten.slice.Tensor %740, %int0_544, %int0_545, %742, %int1_546 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %743, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %741 = torch.aten.slice.Tensor %738, %int0_544, %int0_545, %298, %int1_546 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %741, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_547 = torch.constant.int 1 %int0_548 = torch.constant.int 0 %int9223372036854775807_549 = torch.constant.int 9223372036854775807 %int1_550 = torch.constant.int 1 - %744 = torch.aten.slice.Tensor %743, %int1_547, %int0_548, %int9223372036854775807_549, %int1_550 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %744, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_551 = torch.constant.int 1 + %742 = torch.aten.slice.Tensor %741, %int1_547, %int0_548, %int9223372036854775807_549, %int1_550 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %742, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_551 = torch.constant.int 0 %int0_552 = torch.constant.int 0 - %int9223372036854775807_553 = torch.constant.int 9223372036854775807 + %int1_553 = torch.constant.int 1 + %743 = torch.aten.slice.Tensor %740, %int0_551, %int0_552, %298, %int1_553 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %743, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_554 = torch.constant.int 1 - %745 = torch.aten.slice.Tensor %744, %int1_551, %int0_552, %int9223372036854775807_553, %int1_554 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %745, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> %int0_555 = torch.constant.int 0 - %746 = torch.aten.unsqueeze %745, %int0_555 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %746, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_556 = torch.constant.int 1 - %int0_557 = torch.constant.int 0 - %int9223372036854775807_558 = torch.constant.int 9223372036854775807 + %int9223372036854775807_556 = torch.constant.int 9223372036854775807 + %int1_557 = torch.constant.int 1 + %744 = torch.aten.slice.Tensor %743, %int1_554, %int0_555, %int9223372036854775807_556, %int1_557 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %744, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_558 = torch.constant.int 0 + %745 = torch.aten.unsqueeze %742, %int0_558 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %745, [%294], affine_map<()[s0] -> (1, s0 
* 32, 128)> : !torch.vtensor<[1,?,128],f16> %int1_559 = torch.constant.int 1 - %747 = torch.aten.slice.Tensor %746, %int1_556, %int0_557, %int9223372036854775807_558, %int1_559 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %747, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_560 = torch.constant.int 2 - %int0_561 = torch.constant.int 0 - %int9223372036854775807_562 = torch.constant.int 9223372036854775807 - %int1_563 = torch.constant.int 1 - %748 = torch.aten.slice.Tensor %747, %int2_560, %int0_561, %int9223372036854775807_562, %int1_563 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %748, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_564 = torch.constant.int 4 - %int1_565 = torch.constant.int 1 - %int1_566 = torch.constant.int 1 - %749 = torch.prim.ListConstruct %int4_564, %int1_565, %int1_566 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %750 = torch.aten.repeat %748, %749 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %750, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_567 = torch.constant.int 6 - %751 = torch.prims.convert_element_type %724, %int6_567 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %751, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %752 = torch_c.to_builtin_tensor %751 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %753 = torch_c.to_builtin_tensor %750 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %754 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%752, %753) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %755 = torch_c.from_builtin_tensor %754 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %755, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_568 = torch.constant.int 5 - %756 = torch.prims.convert_element_type %755, %int5_568 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %756, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_569 = torch.constant.int 131072 - %none_570 = torch.constant.none - %none_571 = torch.constant.none - %cpu_572 = torch.constant.device "cpu" - %false_573 = torch.constant.bool false - %757 = torch.aten.arange %int131072_569, %none_570, %none_571, %cpu_572, %false_573 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_560 = torch.constant.int 0 + %int9223372036854775807_561 = torch.constant.int 9223372036854775807 + %int1_562 = torch.constant.int 1 + %746 = torch.aten.slice.Tensor %745, %int1_559, %int0_560, %int9223372036854775807_561, %int1_562 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %746, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_563 = torch.constant.int 2 + %747 = torch.aten.unsqueeze %746, %int2_563 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + 
torch.bind_symbolic_shape %747, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_564 = torch.constant.int 3 + %int0_565 = torch.constant.int 0 + %int9223372036854775807_566 = torch.constant.int 9223372036854775807 + %int1_567 = torch.constant.int 1 + %748 = torch.aten.slice.Tensor %747, %int3_564, %int0_565, %int9223372036854775807_566, %int1_567 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %748, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_568 = torch.constant.int 4 + %int1_569 = torch.constant.int 1 + %int1_570 = torch.constant.int 1 + %int1_571 = torch.constant.int 1 + %749 = torch.prim.ListConstruct %int4_568, %int1_569, %int1_570, %int1_571 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %750 = torch.aten.repeat %748, %749 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %750, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_572 = torch.constant.int 0 + %751 = torch.aten.unsqueeze %744, %int0_572 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %751, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_573 = torch.constant.int 1 %int0_574 = torch.constant.int 0 - %int128_575 = torch.constant.int 128 - %none_576 = torch.constant.none - %none_577 = torch.constant.none - %cpu_578 = torch.constant.device "cpu" - %false_579 = torch.constant.bool false - %758 = torch.aten.arange.start %int0_574, %int128_575, %none_576, %none_577, %cpu_578, %false_579 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_580 = torch.constant.int 2 - %759 = torch.aten.floor_divide.Scalar %758, %int2_580 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_581 = torch.constant.int 6 - %760 = torch.prims.convert_element_type %759, %int6_581 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_582 = torch.constant.int 128 - %761 = torch.aten.div.Scalar %760, %int128_582 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_583 = torch.constant.float 2.000000e+00 - %762 = torch.aten.mul.Scalar %761, %float2.000000e00_583 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_584 = torch.constant.float 5.000000e+05 - %763 = torch.aten.pow.Scalar %float5.000000e05_584, %762 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %764 = torch.aten.reciprocal %763 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_585 = torch.constant.float 1.000000e+00 - %765 = torch.aten.mul.Scalar %764, %float1.000000e00_585 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_586 = torch.constant.int 1 - %766 = torch.aten.unsqueeze %757, %int1_586 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int9223372036854775807_575 = torch.constant.int 9223372036854775807 + %int1_576 = torch.constant.int 1 + %752 = torch.aten.slice.Tensor %751, %int1_573, %int0_574, %int9223372036854775807_575, %int1_576 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %752, [%294], 
affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_577 = torch.constant.int 2 + %753 = torch.aten.unsqueeze %752, %int2_577 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %753, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_578 = torch.constant.int 3 + %int0_579 = torch.constant.int 0 + %int9223372036854775807_580 = torch.constant.int 9223372036854775807 + %int1_581 = torch.constant.int 1 + %754 = torch.aten.slice.Tensor %753, %int3_578, %int0_579, %int9223372036854775807_580, %int1_581 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %754, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_582 = torch.constant.int 4 + %int1_583 = torch.constant.int 1 + %int1_584 = torch.constant.int 1 + %int1_585 = torch.constant.int 1 + %755 = torch.prim.ListConstruct %int4_582, %int1_583, %int1_584, %int1_585 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %756 = torch.aten.repeat %754, %755 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %756, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %757 = torch.aten.mul.Tensor %636, %750 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %757, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_586 = torch.constant.int 3 %int0_587 = torch.constant.int 0 - %767 = torch.aten.unsqueeze %765, %int0_587 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %768 = torch.aten.mul.Tensor %766, %767 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_588 = torch.constant.int 1 - %769 = torch.aten.size.int %715, %int1_588 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_589 = torch.constant.int 0 - %770 = torch.aten.add.int %int0_589, %769 : !torch.int, !torch.int -> !torch.int - %int0_590 = torch.constant.int 0 - %int0_591 = torch.constant.int 0 - %int1_592 = torch.constant.int 1 - %771 = torch.aten.slice.Tensor %768, %int0_590, %int0_591, %770, %int1_592 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %771, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %int64_588 = torch.constant.int 64 + %int1_589 = torch.constant.int 1 + %758 = torch.aten.slice.Tensor %636, %int3_586, %int0_587, %int64_588, %int1_589 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %758, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_590 = torch.constant.int 3 + %int64_591 = torch.constant.int 64 + %int9223372036854775807_592 = torch.constant.int 9223372036854775807 %int1_593 = torch.constant.int 1 - %int0_594 = torch.constant.int 0 - %int9223372036854775807_595 = torch.constant.int 9223372036854775807 - %int1_596 = torch.constant.int 1 - %772 = torch.aten.slice.Tensor %771, %int1_593, %int0_594, %int9223372036854775807_595, %int1_596 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - 
torch.bind_symbolic_shape %772, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %759 = torch.aten.slice.Tensor %636, %int3_590, %int64_591, %int9223372036854775807_592, %int1_593 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %759, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %760 = torch.aten.neg %759 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %760, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %761 = torch.prim.ListConstruct %760, %758 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_594 = torch.constant.int -1 + %762 = torch.aten.cat %761, %int-1_594 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %762, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %763 = torch.aten.mul.Tensor %762, %756 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %763, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_595 = torch.constant.int 1 + %764 = torch.aten.add.Tensor %757, %763, %int1_595 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %764, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_596 = torch.constant.int 32 + %765 = torch.aten.mul.Scalar %arg2, %int32_596 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %765, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> %int1_597 = torch.constant.int 1 - %int0_598 = torch.constant.int 0 - %int9223372036854775807_599 = torch.constant.int 9223372036854775807 - %int1_600 = torch.constant.int 1 - %773 = torch.aten.slice.Tensor %772, %int1_597, %int0_598, %int9223372036854775807_599, %int1_600 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %773, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_601 = torch.constant.int 0 - %774 = torch.aten.unsqueeze %773, %int0_601 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %774, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_602 = torch.constant.int 1 - %int0_603 = torch.constant.int 0 - %int9223372036854775807_604 = torch.constant.int 9223372036854775807 - %int1_605 = torch.constant.int 1 - %775 = torch.aten.slice.Tensor %774, %int1_602, %int0_603, %int9223372036854775807_604, %int1_605 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %775, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_606 = torch.constant.int 2 - %int0_607 = torch.constant.int 0 - %int9223372036854775807_608 = torch.constant.int 9223372036854775807 + %int1_598 = torch.constant.int 1 + %766 = torch.aten.add.Scalar %765, %int1_597, %int1_598 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %766, [%294], affine_map<()[s0] -> (4, s0)> : 
!torch.vtensor<[4,?],si64> + %int2_599 = torch.constant.int 2 + %767 = torch.aten.mul.Scalar %766, %int2_599 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %767, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_600 = torch.constant.int 0 + %int1_601 = torch.constant.int 1 + %768 = torch.aten.add.Scalar %767, %int0_600, %int1_601 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %768, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %769 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %770 = torch.aten.view %768, %769 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %770, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_602 = torch.constant.int 4 + %int32_603 = torch.constant.int 32 + %int8_604 = torch.constant.int 8 + %int128_605 = torch.constant.int 128 + %771 = torch.prim.ListConstruct %int4_602, %296, %int32_603, %int8_604, %int128_605 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %772 = torch.aten.view %764, %771 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %772, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_606 = torch.constant.int 32 + %int8_607 = torch.constant.int 8 + %int128_608 = torch.constant.int 128 + %773 = torch.prim.ListConstruct %504, %int32_606, %int8_607, %int128_608 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %774 = torch.aten.view %772, %773 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %774, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> %int1_609 = torch.constant.int 1 - %776 = torch.aten.slice.Tensor %775, %int2_606, %int0_607, %int9223372036854775807_608, %int1_609 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %776, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_610 = torch.constant.int 4 - %int1_611 = torch.constant.int 1 - %int1_612 = torch.constant.int 1 - %777 = torch.prim.ListConstruct %int4_610, %int1_611, %int1_612 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %778 = torch.aten.repeat %776, %777 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %778, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_613 = torch.constant.int 6 - %779 = torch.prims.convert_element_type %726, %int6_613 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %779, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %780 = torch_c.to_builtin_tensor %779 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %781 = torch_c.to_builtin_tensor %778 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %782 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%780, %781) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %783 = torch_c.from_builtin_tensor %782 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %783, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : 
!torch.vtensor<[4,?,8,128],f32> - %int5_614 = torch.constant.int 5 - %784 = torch.prims.convert_element_type %783, %int5_614 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %784, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_615 = torch.constant.int 64 - %785 = torch.aten.mul.Scalar %arg2, %int64_615 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %785, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_616 = torch.constant.int 4 - %int1_617 = torch.constant.int 1 - %786 = torch.aten.add.Scalar %785, %int4_616, %int1_617 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %786, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_618 = torch.constant.int 4 - %int32_619 = torch.constant.int 32 - %int8_620 = torch.constant.int 8 - %int128_621 = torch.constant.int 128 - %787 = torch.prim.ListConstruct %int4_618, %398, %int32_619, %int8_620, %int128_621 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %788 = torch.aten.view %784, %787 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %788, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_622 = torch.constant.int 4 - %789 = torch.aten.mul.int %int4_622, %398 : !torch.int, !torch.int -> !torch.int - %int32_623 = torch.constant.int 32 - %int8_624 = torch.constant.int 8 + %int2_610 = torch.constant.int 2 + %775 = torch.aten.transpose.int %774, %int1_609, %int2_610 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %775, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_611 = torch.constant.int 5 + %776 = torch.prims.convert_element_type %775, %int5_611 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %776, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_612 = torch.constant.int 32 + %int2_613 = torch.constant.int 2 + %int8_614 = torch.constant.int 8 + %int32_615 = torch.constant.int 32 + %int128_616 = torch.constant.int 128 + %777 = torch.prim.ListConstruct %297, %int32_612, %int2_613, %int8_614, %int32_615, %int128_616 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %778 = torch.aten.view %540, %777 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %778, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_617 = torch.constant.int 8 + %int32_618 = torch.constant.int 32 + %int128_619 = torch.constant.int 128 + %779 = torch.prim.ListConstruct %497, %int8_617, %int32_618, %int128_619 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %780 = torch.aten.view %778, %779 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %780, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %781 = torch.prim.ListConstruct %770 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_620 = torch.constant.bool false + %782 = torch.aten.index_put %780, %781, %776, %false_620 : 
!torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %782, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_621 = torch.constant.int 32 + %int2_622 = torch.constant.int 2 + %int8_623 = torch.constant.int 8 + %int32_624 = torch.constant.int 32 %int128_625 = torch.constant.int 128 - %790 = torch.prim.ListConstruct %789, %int32_623, %int8_624, %int128_625 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %791 = torch.aten.view %788, %790 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %791, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_626 = torch.constant.int 4 - %792 = torch.aten.mul.int %int4_626, %398 : !torch.int, !torch.int -> !torch.int - %793 = torch.prim.ListConstruct %792 : (!torch.int) -> !torch.list - %794 = torch.aten.view %786, %793 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %794, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %783 = torch.prim.ListConstruct %297, %int32_621, %int2_622, %int8_623, %int32_624, %int128_625 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %784 = torch.aten.view %782, %783 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %784, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_626 = torch.constant.int 2097152 + %785 = torch.prim.ListConstruct %297, %int2097152_626 : (!torch.int, !torch.int) -> !torch.list + %786 = torch.aten.view %784, %785 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %786, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> %int32_627 = torch.constant.int 32 %int2_628 = torch.constant.int 2 - %int32_629 = torch.constant.int 32 - %int8_630 = torch.constant.int 8 + %int8_629 = torch.constant.int 8 + %int32_630 = torch.constant.int 32 %int128_631 = torch.constant.int 128 - %795 = torch.prim.ListConstruct %389, %int32_627, %int2_628, %int32_629, %int8_630, %int128_631 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %796 = torch.aten.view %628, %795 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %796, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_632 = torch.constant.int 32 - %797 = torch.aten.mul.int %389, %int32_632 : !torch.int, !torch.int -> !torch.int - %int2_633 = torch.constant.int 2 - %798 = torch.aten.mul.int %797, %int2_633 : !torch.int, !torch.int -> !torch.int - %int32_634 = torch.constant.int 32 - %int8_635 = torch.constant.int 8 - %int128_636 = torch.constant.int 128 - %799 = torch.prim.ListConstruct %798, %int32_634, %int8_635, %int128_636 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %800 = torch.aten.view %796, %799 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %800, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %801 = torch.prim.ListConstruct %794 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_637 = 
torch.constant.bool false - %802 = torch.aten.index_put %800, %801, %791, %false_637 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %802, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_638 = torch.constant.int 32 - %int2_639 = torch.constant.int 2 - %int32_640 = torch.constant.int 32 - %int8_641 = torch.constant.int 8 - %int128_642 = torch.constant.int 128 - %803 = torch.prim.ListConstruct %389, %int32_638, %int2_639, %int32_640, %int8_641, %int128_642 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %804 = torch.aten.view %802, %803 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %804, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_643 = torch.constant.int 2097152 - %805 = torch.prim.ListConstruct %389, %int2097152_643 : (!torch.int, !torch.int) -> !torch.list - %806 = torch.aten.view %804, %805 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %806, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_644 = torch.constant.int 32 - %int2_645 = torch.constant.int 2 - %int32_646 = torch.constant.int 32 - %int8_647 = torch.constant.int 8 - %int128_648 = torch.constant.int 128 - %807 = torch.prim.ListConstruct %389, %int32_644, %int2_645, %int32_646, %int8_647, %int128_648 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %808 = torch.aten.view %806, %807 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %808, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_649 = torch.constant.int 32 - %int8_650 = torch.constant.int 8 - %int128_651 = torch.constant.int 128 - %809 = torch.prim.ListConstruct %798, %int32_649, %int8_650, %int128_651 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %810 = torch.aten.view %808, %809 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %810, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_652 = torch.constant.int 4 - %int32_653 = torch.constant.int 32 + %787 = torch.prim.ListConstruct %297, %int32_627, %int2_628, %int8_629, %int32_630, %int128_631 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %788 = torch.aten.view %786, %787 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %788, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_632 = torch.constant.int 8 + %int32_633 = torch.constant.int 32 + %int128_634 = torch.constant.int 128 + %789 = torch.prim.ListConstruct %497, %int8_632, %int32_633, %int128_634 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %790 = torch.aten.view %788, %789 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %790, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_635 = torch.constant.int 32 + %791 = torch.aten.mul.Scalar %arg2, 
%int32_635 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %791, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_636 = torch.constant.int 1 + %int1_637 = torch.constant.int 1 + %792 = torch.aten.add.Scalar %791, %int1_636, %int1_637 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %792, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_638 = torch.constant.int 2 + %793 = torch.aten.mul.Scalar %792, %int2_638 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %793, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_639 = torch.constant.int 1 + %int1_640 = torch.constant.int 1 + %794 = torch.aten.add.Scalar %793, %int1_639, %int1_640 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %794, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %795 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %796 = torch.aten.view %794, %795 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %796, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_641 = torch.constant.int 4 + %int32_642 = torch.constant.int 32 + %int8_643 = torch.constant.int 8 + %int128_644 = torch.constant.int 128 + %797 = torch.prim.ListConstruct %int4_641, %296, %int32_642, %int8_643, %int128_644 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %798 = torch.aten.view %638, %797 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %798, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_645 = torch.constant.int 32 + %int8_646 = torch.constant.int 8 + %int128_647 = torch.constant.int 128 + %799 = torch.prim.ListConstruct %504, %int32_645, %int8_646, %int128_647 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %800 = torch.aten.view %798, %799 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %800, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_648 = torch.constant.int 1 + %int2_649 = torch.constant.int 2 + %801 = torch.aten.transpose.int %800, %int1_648, %int2_649 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %801, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_650 = torch.constant.int 5 + %802 = torch.prims.convert_element_type %801, %int5_650 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %802, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %803 = torch.prim.ListConstruct %796 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_651 = torch.constant.bool false + %804 = torch.aten.index_put %790, %803, %802, %false_651 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %804, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_652 = torch.constant.int 32 + %int2_653 = torch.constant.int 2 %int8_654 = 
torch.constant.int 8 - %int128_655 = torch.constant.int 128 - %811 = torch.prim.ListConstruct %int4_652, %398, %int32_653, %int8_654, %int128_655 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %812 = torch.aten.view %728, %811 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %812, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_656 = torch.constant.int 4 - %813 = torch.aten.mul.int %int4_656, %398 : !torch.int, !torch.int -> !torch.int - %int32_657 = torch.constant.int 32 - %int8_658 = torch.constant.int 8 - %int128_659 = torch.constant.int 128 - %814 = torch.prim.ListConstruct %813, %int32_657, %int8_658, %int128_659 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %815 = torch.aten.view %812, %814 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %815, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_660 = torch.constant.int 1 - %int1_661 = torch.constant.int 1 - %816 = torch.aten.add.Scalar %786, %int1_660, %int1_661 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %816, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_662 = torch.constant.int 4 - %817 = torch.aten.mul.int %int4_662, %398 : !torch.int, !torch.int -> !torch.int - %818 = torch.prim.ListConstruct %817 : (!torch.int) -> !torch.list - %819 = torch.aten.view %816, %818 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %819, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %820 = torch.prim.ListConstruct %819 : (!torch.vtensor<[?],si64>) -> !torch.list> + %int32_655 = torch.constant.int 32 + %int128_656 = torch.constant.int 128 + %805 = torch.prim.ListConstruct %297, %int32_652, %int2_653, %int8_654, %int32_655, %int128_656 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %806 = torch.aten.view %804, %805 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %806, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_657 = torch.constant.int 2097152 + %807 = torch.prim.ListConstruct %297, %int2097152_657 : (!torch.int, !torch.int) -> !torch.list + %808 = torch.aten.view %806, %807 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %808, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_658 = torch.constant.int -2 + %809 = torch.aten.unsqueeze %764, %int-2_658 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %809, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_659 = torch.constant.int 4 + %int8_660 = torch.constant.int 8 + %int4_661 = torch.constant.int 4 + %int128_662 = torch.constant.int 128 + %810 = torch.prim.ListConstruct %int4_659, %298, %int8_660, %int4_661, %int128_662 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %false_663 = torch.constant.bool false - %821 = torch.aten.index_put %810, %820, %815, %false_663 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, 
!torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %821, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_664 = torch.constant.int 32 - %int2_665 = torch.constant.int 2 + %811 = torch.aten.expand %809, %810, %false_663 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %811, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_664 = torch.constant.int 0 + %812 = torch.aten.clone %811, %int0_664 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %812, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_665 = torch.constant.int 4 %int32_666 = torch.constant.int 32 - %int8_667 = torch.constant.int 8 - %int128_668 = torch.constant.int 128 - %822 = torch.prim.ListConstruct %389, %int32_664, %int2_665, %int32_666, %int8_667, %int128_668 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %823 = torch.aten.view %821, %822 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %823, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_669 = torch.constant.int 2097152 - %824 = torch.prim.ListConstruct %389, %int2097152_669 : (!torch.int, !torch.int) -> !torch.list - %825 = torch.aten.view %823, %824 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %825, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_670 = torch.constant.int -2 - %826 = torch.aten.unsqueeze %784, %int-2_670 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %826, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int128_667 = torch.constant.int 128 + %813 = torch.prim.ListConstruct %int4_665, %298, %int32_666, %int128_667 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %814 = torch.aten._unsafe_view %812, %813 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %814, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_668 = torch.constant.int -2 + %815 = torch.aten.unsqueeze %638, %int-2_668 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %815, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_669 = torch.constant.int 4 + %int8_670 = torch.constant.int 8 %int4_671 = torch.constant.int 4 - %int8_672 = torch.constant.int 8 - %int4_673 = torch.constant.int 4 - %int128_674 = torch.constant.int 128 - %827 = torch.prim.ListConstruct %int4_671, %769, %int8_672, %int4_673, %int128_674 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_675 = torch.constant.bool false - %828 = torch.aten.expand %826, %827, %false_675 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %828, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_676 = torch.constant.int 0 - %829 = torch.aten.clone %828, %int0_676 : 
!torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %829, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_677 = torch.constant.int 4 - %int32_678 = torch.constant.int 32 - %int128_679 = torch.constant.int 128 - %830 = torch.prim.ListConstruct %int4_677, %769, %int32_678, %int128_679 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %831 = torch.aten._unsafe_view %829, %830 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %831, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_680 = torch.constant.int -2 - %832 = torch.aten.unsqueeze %728, %int-2_680 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %832, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_681 = torch.constant.int 1 - %833 = torch.aten.size.int %722, %int1_681 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_682 = torch.constant.int 4 - %int8_683 = torch.constant.int 8 - %int4_684 = torch.constant.int 4 - %int128_685 = torch.constant.int 128 - %834 = torch.prim.ListConstruct %int4_682, %833, %int8_683, %int4_684, %int128_685 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_686 = torch.constant.bool false - %835 = torch.aten.expand %832, %834, %false_686 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %835, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_687 = torch.constant.int 0 - %836 = torch.aten.clone %835, %int0_687 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %836, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_688 = torch.constant.int 4 - %int32_689 = torch.constant.int 32 - %int128_690 = torch.constant.int 128 - %837 = torch.prim.ListConstruct %int4_688, %833, %int32_689, %int128_690 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %838 = torch.aten._unsafe_view %836, %837 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %838, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_691 = torch.constant.int 1 - %int2_692 = torch.constant.int 2 - %839 = torch.aten.transpose.int %756, %int1_691, %int2_692 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %839, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_693 = torch.constant.int 1 - %int2_694 = torch.constant.int 2 - %840 = torch.aten.transpose.int %831, %int1_693, %int2_694 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %840, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_695 = torch.constant.int 1 - %int2_696 = torch.constant.int 2 - %841 = torch.aten.transpose.int %838, %int1_695, %int2_696 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %841, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> 
: !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_697 = torch.constant.float 0.000000e+00 - %true_698 = torch.constant.bool true - %none_699 = torch.constant.none - %none_700 = torch.constant.none - %842:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%839, %840, %841, %float0.000000e00_697, %true_698, %none_699, %none_700) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %842#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_701 = torch.constant.int 1 - %int2_702 = torch.constant.int 2 - %843 = torch.aten.transpose.int %842#0, %int1_701, %int2_702 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %843, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_703 = torch.constant.int 4 - %int4096_704 = torch.constant.int 4096 - %844 = torch.prim.ListConstruct %int4_703, %741, %int4096_704 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %845 = torch.aten.view %843, %844 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %845, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_705 = torch.constant.int -2 - %int-1_706 = torch.constant.int -1 - %846 = torch.aten.transpose.int %23, %int-2_705, %int-1_706 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_707 = torch.constant.int 4 - %847 = torch.aten.mul.int %int4_707, %741 : !torch.int, !torch.int -> !torch.int - %int4096_708 = torch.constant.int 4096 - %848 = torch.prim.ListConstruct %847, %int4096_708 : (!torch.int, !torch.int) -> !torch.list - %849 = torch.aten.view %845, %848 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %849, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %850 = torch.aten.mm %849, %846 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %850, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_709 = torch.constant.int 4 + %int128_672 = torch.constant.int 128 + %816 = torch.prim.ListConstruct %int4_669, %298, %int8_670, %int4_671, %int128_672 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_673 = torch.constant.bool false + %817 = torch.aten.expand %815, %816, %false_673 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %817, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_674 = torch.constant.int 0 + %818 = torch.aten.clone %817, %int0_674 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %818, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_675 = torch.constant.int 4 + %int32_676 = torch.constant.int 32 + %int128_677 = torch.constant.int 128 + %819 = torch.prim.ListConstruct %int4_675, %298, %int32_676, %int128_677 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %820 = 
torch.aten._unsafe_view %818, %819 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %820, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_678 = torch.constant.int 1 + %int2_679 = torch.constant.int 2 + %821 = torch.aten.transpose.int %701, %int1_678, %int2_679 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %821, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_680 = torch.constant.int 1 + %int2_681 = torch.constant.int 2 + %822 = torch.aten.transpose.int %814, %int1_680, %int2_681 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %822, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_682 = torch.constant.int 1 + %int2_683 = torch.constant.int 2 + %823 = torch.aten.transpose.int %820, %int1_682, %int2_683 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %823, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_684 = torch.constant.float 0.000000e+00 + %false_685 = torch.constant.bool false + %none_686 = torch.constant.none + %824:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%821, %822, %823, %float0.000000e00_684, %false_685, %327, %none_686) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %824#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_687 = torch.constant.int 1 + %int2_688 = torch.constant.int 2 + %825 = torch.aten.transpose.int %824#0, %int1_687, %int2_688 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %825, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_689 = torch.constant.int 4 + %int4096_690 = torch.constant.int 4096 + %826 = torch.prim.ListConstruct %int4_689, %298, %int4096_690 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %827 = torch.aten.view %825, %826 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %827, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_691 = torch.constant.int -2 + %int-1_692 = torch.constant.int -1 + %828 = torch.aten.transpose.int %15, %int-2_691, %int-1_692 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_693 = torch.constant.int 5 + %829 = torch.prims.convert_element_type %828, %int5_693 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_694 = torch.constant.int 4096 + %830 = torch.prim.ListConstruct %342, %int4096_694 : (!torch.int, !torch.int) -> !torch.list + %831 = torch.aten.view %827, %830 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %831, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %832 = torch.aten.mm %831, %829 : !torch.vtensor<[?,4096],f16>, 
!torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %832, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_695 = torch.constant.int 4 + %int4096_696 = torch.constant.int 4096 + %833 = torch.prim.ListConstruct %int4_695, %298, %int4096_696 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %834 = torch.aten.view %832, %833 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %834, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_697 = torch.constant.int 1 + %835 = torch.aten.add.Tensor %601, %834, %int1_697 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %835, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_698 = torch.constant.int 6 + %836 = torch.prims.convert_element_type %835, %int6_698 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %836, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_699 = torch.constant.int 2 + %837 = torch.aten.pow.Tensor_Scalar %836, %int2_699 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %837, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_700 = torch.constant.int -1 + %838 = torch.prim.ListConstruct %int-1_700 : (!torch.int) -> !torch.list + %true_701 = torch.constant.bool true + %none_702 = torch.constant.none + %839 = torch.aten.mean.dim %837, %838, %true_701, %none_702 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %839, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_703 = torch.constant.float 9.9999997473787516E-6 + %int1_704 = torch.constant.int 1 + %840 = torch.aten.add.Scalar %839, %float9.999990e-06_703, %int1_704 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %840, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %841 = torch.aten.rsqrt %840 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %841, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %842 = torch.aten.mul.Tensor %836, %841 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %842, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_705 = torch.constant.int 5 + %843 = torch.prims.convert_element_type %842, %int5_705 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %843, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %844 = torch.aten.mul.Tensor %16, %843 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %844, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_706 = torch.constant.int 5 + %845 = torch.prims.convert_element_type %844, %int5_706 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %845, [%294], affine_map<()[s0] -> (4, 
s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_707 = torch.constant.int -2 + %int-1_708 = torch.constant.int -1 + %846 = torch.aten.transpose.int %17, %int-2_707, %int-1_708 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_709 = torch.constant.int 5 + %847 = torch.prims.convert_element_type %846, %int5_709 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> %int4096_710 = torch.constant.int 4096 - %851 = torch.prim.ListConstruct %int4_709, %741, %int4096_710 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %852 = torch.aten.view %850, %851 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %852, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_711 = torch.constant.int 1 - %853 = torch.aten.add.Tensor %691, %852, %int1_711 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %853, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_712 = torch.constant.int 6 - %854 = torch.prims.convert_element_type %853, %int6_712 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %854, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_713 = torch.constant.int 2 - %855 = torch.aten.pow.Tensor_Scalar %854, %int2_713 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %855, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %848 = torch.prim.ListConstruct %342, %int4096_710 : (!torch.int, !torch.int) -> !torch.list + %849 = torch.aten.view %845, %848 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %849, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %850 = torch.aten.mm %849, %847 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %850, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_711 = torch.constant.int 4 + %int14336_712 = torch.constant.int 14336 + %851 = torch.prim.ListConstruct %int4_711, %298, %int14336_712 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %852 = torch.aten.view %850, %851 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %852, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %853 = torch.aten.silu %852 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %853, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_713 = torch.constant.int -2 %int-1_714 = torch.constant.int -1 - %856 = torch.prim.ListConstruct %int-1_714 : (!torch.int) -> !torch.list - %true_715 = torch.constant.bool true - %none_716 = torch.constant.none - %857 = torch.aten.mean.dim %855, %856, %true_715, %none_716 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %857, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_717 = torch.constant.float 9.9999997473787516E-6 - %int1_718 = torch.constant.int 1 - %858 = 
torch.aten.add.Scalar %857, %float9.999990e-06_717, %int1_718 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %858, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %859 = torch.aten.rsqrt %858 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %859, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %860 = torch.aten.mul.Tensor %854, %859 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %860, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_719 = torch.constant.int 5 - %861 = torch.prims.convert_element_type %860, %int5_719 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %861, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %862 = torch.aten.mul.Tensor %24, %861 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %862, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_720 = torch.constant.int 5 - %863 = torch.prims.convert_element_type %862, %int5_720 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %863, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_721 = torch.constant.int -2 - %int-1_722 = torch.constant.int -1 - %864 = torch.aten.transpose.int %25, %int-2_721, %int-1_722 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %854 = torch.aten.transpose.int %18, %int-2_713, %int-1_714 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_715 = torch.constant.int 5 + %855 = torch.prims.convert_element_type %854, %int5_715 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_716 = torch.constant.int 4096 + %856 = torch.prim.ListConstruct %342, %int4096_716 : (!torch.int, !torch.int) -> !torch.list + %857 = torch.aten.view %845, %856 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %857, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %858 = torch.aten.mm %857, %855 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %858, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_717 = torch.constant.int 4 + %int14336_718 = torch.constant.int 14336 + %859 = torch.prim.ListConstruct %int4_717, %298, %int14336_718 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %860 = torch.aten.view %858, %859 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %860, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %861 = torch.aten.mul.Tensor %853, %860 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %861, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_719 = torch.constant.int -2 + %int-1_720 = torch.constant.int -1 + %862 = torch.aten.transpose.int %19, %int-2_719, %int-1_720 : 
!torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
+ %int5_721 = torch.constant.int 5
+ %863 = torch.prims.convert_element_type %862, %int5_721 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16>
+ %int14336_722 = torch.constant.int 14336
+ %864 = torch.prim.ListConstruct %342, %int14336_722 : (!torch.int, !torch.int) -> !torch.list
+ %865 = torch.aten.view %861, %864 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16>
+ torch.bind_symbolic_shape %865, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
+ %866 = torch.aten.mm %865, %863 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %866, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
 %int4_723 = torch.constant.int 4
- %865 = torch.aten.mul.int %int4_723, %306 : !torch.int, !torch.int -> !torch.int
 %int4096_724 = torch.constant.int 4096
- %866 = torch.prim.ListConstruct %865, %int4096_724 : (!torch.int, !torch.int) -> !torch.list
- %867 = torch.aten.view %863, %866 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %867, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %868 = torch.aten.mm %867, %864 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %868, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %int4_725 = torch.constant.int 4
- %int14336_726 = torch.constant.int 14336
- %869 = torch.prim.ListConstruct %int4_725, %306, %int14336_726 : (!torch.int, !torch.int, !torch.int) -> !torch.list
- %870 = torch.aten.view %868, %869 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %870, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %871 = torch.aten.silu %870 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %871, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %int-2_727 = torch.constant.int -2
+ %867 = torch.prim.ListConstruct %int4_723, %298, %int4096_724 : (!torch.int, !torch.int, !torch.int) -> !torch.list
+ %868 = torch.aten.view %866, %867 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %868, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int1_725 = torch.constant.int 1
+ %869 = torch.aten.add.Tensor %835, %868, %int1_725 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %869, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int6_726 = torch.constant.int 6
+ %870 = torch.prims.convert_element_type %869, %int6_726 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %870, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int2_727 = torch.constant.int 2
+ %871 = torch.aten.pow.Tensor_Scalar %870, %int2_727 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %871, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
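
Reviewer note: this hunk finishes the block's feed-forward network and hands off to the next norm: the SwiGLU product `%861` (silu(gate) * up) is flattened, multiplied by the transposed `ffn_down` weight (`%866`), reshaped back to [4, ?, 4096] (`%868`), and added to the residual stream (`%869`); the `pow`/`mean`/`rsqrt` chain beginning at `%870` is the next layer's RMSNorm. A compact sketch of the math this IR appears to implement, with names of my own choosing and the epsilon taken from the `9.9999997473787516E-6` constant nearby:

    import torch
    import torch.nn.functional as F

    def rms_norm(x, weight, eps=9.9999997473787516e-06):
        # pow(2) -> mean(-1, keepdim) -> + eps -> rsqrt -> scale, computed in f32;
        # cast back to f16 after the f32 weight is applied, as in the IR
        x32 = x.float()
        normed = (x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)
        return (weight * normed).to(x.dtype)

    def swiglu_ffn(x, w_gate, w_up, w_down):
        # w_gate, w_up: [14336, 4096]; w_down: [4096, 14336]; all f16 here
        return (F.silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T

    # per block: residual = residual + swiglu_ffn(rms_norm(residual, g), wg, wu, wd)

 %int-1_728 = 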
torch.constant.int -1 - %872 = torch.aten.transpose.int %26, %int-2_727, %int-1_728 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_729 = torch.constant.int 4 - %873 = torch.aten.mul.int %int4_729, %306 : !torch.int, !torch.int -> !torch.int - %int4096_730 = torch.constant.int 4096 - %874 = torch.prim.ListConstruct %873, %int4096_730 : (!torch.int, !torch.int) -> !torch.list - %875 = torch.aten.view %863, %874 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %875, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %876 = torch.aten.mm %875, %872 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %876, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_731 = torch.constant.int 4 - %int14336_732 = torch.constant.int 14336 - %877 = torch.prim.ListConstruct %int4_731, %306, %int14336_732 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %878 = torch.aten.view %876, %877 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %878, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %879 = torch.aten.mul.Tensor %871, %878 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %879, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_733 = torch.constant.int -2 - %int-1_734 = torch.constant.int -1 - %880 = torch.aten.transpose.int %27, %int-2_733, %int-1_734 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_735 = torch.constant.int 1 - %881 = torch.aten.size.int %870, %int1_735 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_736 = torch.constant.int 4 - %882 = torch.aten.mul.int %int4_736, %881 : !torch.int, !torch.int -> !torch.int - %int14336_737 = torch.constant.int 14336 - %883 = torch.prim.ListConstruct %882, %int14336_737 : (!torch.int, !torch.int) -> !torch.list - %884 = torch.aten.view %879, %883 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %884, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %885 = torch.aten.mm %884, %880 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %885, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_738 = torch.constant.int 4 - %int4096_739 = torch.constant.int 4096 - %886 = torch.prim.ListConstruct %int4_738, %881, %int4096_739 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %887 = torch.aten.view %885, %886 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %887, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_740 = torch.constant.int 1 - %888 = torch.aten.add.Tensor %853, %887, %int1_740 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %888, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_741 = torch.constant.int 6 - %889 = torch.prims.convert_element_type %888, %int6_741 : 
!torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %889, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_742 = torch.constant.int 2 - %890 = torch.aten.pow.Tensor_Scalar %889, %int2_742 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %890, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_743 = torch.constant.int -1 - %891 = torch.prim.ListConstruct %int-1_743 : (!torch.int) -> !torch.list - %true_744 = torch.constant.bool true - %none_745 = torch.constant.none - %892 = torch.aten.mean.dim %890, %891, %true_744, %none_745 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %892, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_746 = torch.constant.float 9.9999997473787516E-6 - %int1_747 = torch.constant.int 1 - %893 = torch.aten.add.Scalar %892, %float9.999990e-06_746, %int1_747 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %893, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %894 = torch.aten.rsqrt %893 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %894, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %895 = torch.aten.mul.Tensor %889, %894 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %895, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_748 = torch.constant.int 5 - %896 = torch.prims.convert_element_type %895, %int5_748 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %896, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %897 = torch.aten.mul.Tensor %28, %896 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %897, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %872 = torch.prim.ListConstruct %int-1_728 : (!torch.int) -> !torch.list + %true_729 = torch.constant.bool true + %none_730 = torch.constant.none + %873 = torch.aten.mean.dim %871, %872, %true_729, %none_730 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %873, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_731 = torch.constant.float 9.9999997473787516E-6 + %int1_732 = torch.constant.int 1 + %874 = torch.aten.add.Scalar %873, %float9.999990e-06_731, %int1_732 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %874, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %875 = torch.aten.rsqrt %874 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %875, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %876 = torch.aten.mul.Tensor %870, %875 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %876, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : 
!torch.vtensor<[4,?,4096],f32> + %int5_733 = torch.constant.int 5 + %877 = torch.prims.convert_element_type %876, %int5_733 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %877, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %878 = torch.aten.mul.Tensor %20, %877 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %878, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_734 = torch.constant.int 5 + %879 = torch.prims.convert_element_type %878, %int5_734 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %879, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_735 = torch.constant.int -2 + %int-1_736 = torch.constant.int -1 + %880 = torch.aten.transpose.int %21, %int-2_735, %int-1_736 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_737 = torch.constant.int 5 + %881 = torch.prims.convert_element_type %880, %int5_737 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_738 = torch.constant.int 4096 + %882 = torch.prim.ListConstruct %342, %int4096_738 : (!torch.int, !torch.int) -> !torch.list + %883 = torch.aten.view %879, %882 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %883, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %884 = torch.aten.mm %883, %881 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %884, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_739 = torch.constant.int 4 + %int4096_740 = torch.constant.int 4096 + %885 = torch.prim.ListConstruct %int4_739, %298, %int4096_740 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %886 = torch.aten.view %884, %885 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %886, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_741 = torch.constant.int -2 + %int-1_742 = torch.constant.int -1 + %887 = torch.aten.transpose.int %22, %int-2_741, %int-1_742 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_743 = torch.constant.int 5 + %888 = torch.prims.convert_element_type %887, %int5_743 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_744 = torch.constant.int 4096 + %889 = torch.prim.ListConstruct %342, %int4096_744 : (!torch.int, !torch.int) -> !torch.list + %890 = torch.aten.view %879, %889 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %890, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %891 = torch.aten.mm %890, %888 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %891, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_745 = torch.constant.int 4 + %int1024_746 = torch.constant.int 1024 + %892 = torch.prim.ListConstruct %int4_745, %298, %int1024_746 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %893 = torch.aten.view %891, %892 : 
!torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %893, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_747 = torch.constant.int -2 + %int-1_748 = torch.constant.int -1 + %894 = torch.aten.transpose.int %23, %int-2_747, %int-1_748 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> %int5_749 = torch.constant.int 5 - %898 = torch.prims.convert_element_type %897, %int5_749 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %898, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_750 = torch.constant.int -2 - %int-1_751 = torch.constant.int -1 - %899 = torch.aten.transpose.int %29, %int-2_750, %int-1_751 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_752 = torch.constant.int 4 - %900 = torch.aten.mul.int %int4_752, %306 : !torch.int, !torch.int -> !torch.int - %int4096_753 = torch.constant.int 4096 - %901 = torch.prim.ListConstruct %900, %int4096_753 : (!torch.int, !torch.int) -> !torch.list - %902 = torch.aten.view %898, %901 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %902, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %903 = torch.aten.mm %902, %899 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %903, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_754 = torch.constant.int 4 - %int4096_755 = torch.constant.int 4096 - %904 = torch.prim.ListConstruct %int4_754, %306, %int4096_755 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %905 = torch.aten.view %903, %904 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %905, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_756 = torch.constant.int -2 - %int-1_757 = torch.constant.int -1 - %906 = torch.aten.transpose.int %30, %int-2_756, %int-1_757 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_758 = torch.constant.int 4 - %907 = torch.aten.mul.int %int4_758, %306 : !torch.int, !torch.int -> !torch.int - %int4096_759 = torch.constant.int 4096 - %908 = torch.prim.ListConstruct %907, %int4096_759 : (!torch.int, !torch.int) -> !torch.list - %909 = torch.aten.view %898, %908 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %909, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %910 = torch.aten.mm %909, %906 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %910, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_760 = torch.constant.int 4 - %int1024_761 = torch.constant.int 1024 - %911 = torch.prim.ListConstruct %int4_760, %306, %int1024_761 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %912 = torch.aten.view %910, %911 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %912, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_762 = torch.constant.int -2 - %int-1_763 = torch.constant.int -1 - %913 = 
torch.aten.transpose.int %31, %int-2_762, %int-1_763 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_764 = torch.constant.int 4 - %914 = torch.aten.mul.int %int4_764, %306 : !torch.int, !torch.int -> !torch.int - %int4096_765 = torch.constant.int 4096 - %915 = torch.prim.ListConstruct %914, %int4096_765 : (!torch.int, !torch.int) -> !torch.list - %916 = torch.aten.view %898, %915 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %916, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %917 = torch.aten.mm %916, %913 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %917, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_766 = torch.constant.int 4 - %int1024_767 = torch.constant.int 1024 - %918 = torch.prim.ListConstruct %int4_766, %306, %int1024_767 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %919 = torch.aten.view %917, %918 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %919, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_768 = torch.constant.int 4 - %int32_769 = torch.constant.int 32 - %int128_770 = torch.constant.int 128 - %920 = torch.prim.ListConstruct %int4_768, %306, %int32_769, %int128_770 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %921 = torch.aten.view %905, %920 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %921, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_771 = torch.constant.int 4 - %int8_772 = torch.constant.int 8 - %int128_773 = torch.constant.int 128 - %922 = torch.prim.ListConstruct %int4_771, %306, %int8_772, %int128_773 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %923 = torch.aten.view %912, %922 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %923, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_774 = torch.constant.int 4 - %int8_775 = torch.constant.int 8 - %int128_776 = torch.constant.int 128 - %924 = torch.prim.ListConstruct %int4_774, %306, %int8_775, %int128_776 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %925 = torch.aten.view %919, %924 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %925, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_777 = torch.constant.int 131072 - %none_778 = torch.constant.none - %none_779 = torch.constant.none - %cpu_780 = torch.constant.device "cpu" - %false_781 = torch.constant.bool false - %926 = torch.aten.arange %int131072_777, %none_778, %none_779, %cpu_780, %false_781 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_782 = torch.constant.int 0 - %int128_783 = torch.constant.int 128 - %none_784 = torch.constant.none - %none_785 = torch.constant.none - %cpu_786 = torch.constant.device "cpu" - %false_787 = torch.constant.bool false - %927 = torch.aten.arange.start %int0_782, %int128_783, %none_784, %none_785, %cpu_786, %false_787 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> 
!torch.vtensor<[128],si64> - %int2_788 = torch.constant.int 2 - %928 = torch.aten.floor_divide.Scalar %927, %int2_788 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_789 = torch.constant.int 6 - %929 = torch.prims.convert_element_type %928, %int6_789 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_790 = torch.constant.int 128 - %930 = torch.aten.div.Scalar %929, %int128_790 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_791 = torch.constant.float 2.000000e+00 - %931 = torch.aten.mul.Scalar %930, %float2.000000e00_791 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_792 = torch.constant.float 5.000000e+05 - %932 = torch.aten.pow.Scalar %float5.000000e05_792, %931 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %933 = torch.aten.reciprocal %932 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_793 = torch.constant.float 1.000000e+00 - %934 = torch.aten.mul.Scalar %933, %float1.000000e00_793 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_794 = torch.constant.int 1 - %935 = torch.aten.unsqueeze %926, %int1_794 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %895 = torch.prims.convert_element_type %894, %int5_749 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_750 = torch.constant.int 4096 + %896 = torch.prim.ListConstruct %342, %int4096_750 : (!torch.int, !torch.int) -> !torch.list + %897 = torch.aten.view %879, %896 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %897, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %898 = torch.aten.mm %897, %895 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %898, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_751 = torch.constant.int 4 + %int1024_752 = torch.constant.int 1024 + %899 = torch.prim.ListConstruct %int4_751, %298, %int1024_752 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %900 = torch.aten.view %898, %899 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %900, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_753 = torch.constant.int 4 + %int32_754 = torch.constant.int 32 + %int128_755 = torch.constant.int 128 + %901 = torch.prim.ListConstruct %int4_753, %298, %int32_754, %int128_755 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %902 = torch.aten.view %886, %901 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %902, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_756 = torch.constant.int 4 + %int8_757 = torch.constant.int 8 + %int128_758 = torch.constant.int 128 + %903 = torch.prim.ListConstruct %int4_756, %298, %int8_757, %int128_758 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %904 = torch.aten.view %893, %903 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %904, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_759 = torch.constant.int 4 + %int8_760 = torch.constant.int 8 + 
%int128_761 = torch.constant.int 128 + %905 = torch.prim.ListConstruct %int4_759, %298, %int8_760, %int128_761 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %906 = torch.aten.view %900, %905 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %906, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_762 = torch.constant.int 131072 + %none_763 = torch.constant.none + %none_764 = torch.constant.none + %cpu_765 = torch.constant.device "cpu" + %false_766 = torch.constant.bool false + %907 = torch.aten.arange %int131072_762, %none_763, %none_764, %cpu_765, %false_766 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_767 = torch.constant.int 0 + %int128_768 = torch.constant.int 128 + %int2_769 = torch.constant.int 2 + %int4_770 = torch.constant.int 4 + %none_771 = torch.constant.none + %cpu_772 = torch.constant.device "cpu" + %false_773 = torch.constant.bool false + %908 = torch.aten.arange.start_step %int0_767, %int128_768, %int2_769, %int4_770, %none_771, %cpu_772, %false_773 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_774 = torch.constant.int 6 + %909 = torch.prims.convert_element_type %908, %int6_774 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_775 = torch.constant.int 128 + %910 = torch.aten.div.Scalar %909, %int128_775 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_776 = torch.constant.float 5.000000e+05 + %911 = torch.aten.pow.Scalar %float5.000000e05_776, %910 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %912 = torch.aten.reciprocal %911 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_777 = torch.constant.float 1.000000e+00 + %913 = torch.aten.mul.Scalar %912, %float1.000000e00_777 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %914 = torch.aten.reciprocal %913 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_778 = torch.constant.float 6.2831853071795862 + %915 = torch.aten.mul.Scalar %914, %float6.283190e00_778 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_779 = torch.constant.float 8.192000e+03 + %916 = torch.aten.gt.Scalar %915, %float8.192000e03_779 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_780 = torch.constant.int 8 + %917 = torch.aten.div.Scalar %913, %int8_780 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %918 = torch.aten.where.self %916, %917, %913 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %919 = torch.aten.reciprocal %915 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_781 = torch.constant.int 8192 + %920 = torch.aten.mul.Scalar %919, %int8192_781 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_782 = torch.constant.int 1 + %int1_783 = torch.constant.int 1 + %921 = torch.aten.sub.Scalar %920, %int1_782, %int1_783 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_784 = torch.constant.int 3 + %922 = torch.aten.div.Scalar %921, %int3_784 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_785 = torch.constant.int 1 + %int1_786 = torch.constant.int 1 + %923 = torch.aten.rsub.Scalar 
%922, %int1_785, %int1_786 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %924 = torch.aten.mul.Tensor %923, %918 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8_787 = torch.constant.int 8
+ %925 = torch.aten.div.Scalar %924, %int8_787 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %926 = torch.aten.mul.Tensor %922, %918 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int1_788 = torch.constant.int 1
+ %927 = torch.aten.add.Tensor %925, %926, %int1_788 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float2.048000e03_789 = torch.constant.float 2.048000e+03
+ %928 = torch.aten.lt.Scalar %915, %float2.048000e03_789 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %929 = torch.aten.bitwise_not %928 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %float8.192000e03_790 = torch.constant.float 8.192000e+03
+ %930 = torch.aten.gt.Scalar %915, %float8.192000e03_790 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %931 = torch.aten.bitwise_not %930 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %932 = torch.aten.mul.Tensor %929, %931 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %933 = torch.aten.where.self %932, %927, %918 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %934 = torch.prim.ListConstruct %933, %933 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list
+ %int-1_791 = torch.constant.int -1
+ %935 = torch.aten.cat %934, %int-1_791 : !torch.list, !torch.int -> !torch.vtensor<[128],f32>
+ %int6_792 = torch.constant.int 6
+ %936 = torch.prims.convert_element_type %935, %int6_792 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
+ %int1_793 = torch.constant.int 1
+ %937 = torch.aten.unsqueeze %907, %int1_793 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
+ %int6_794 = torch.constant.int 6
+ %938 = torch.prims.convert_element_type %937, %int6_794 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32>
 %int0_795 = torch.constant.int 0
- %936 = torch.aten.unsqueeze %934, %int0_795 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
- %937 = torch.aten.mul.Tensor %935, %936 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
- %int1_796 = torch.constant.int 1
- %938 = torch.aten.size.int %905, %int1_796 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int
- %int0_797 = torch.constant.int 0
- %939 = torch.aten.add.int %int0_797, %938 : !torch.int, !torch.int -> !torch.int
- %int0_798 = torch.constant.int 0
+ %939 = torch.aten.unsqueeze %936, %int0_795 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+ %int6_796 = torch.constant.int 6
+ %940 = torch.prims.convert_element_type %939, %int6_796 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+ %941 = torch.aten.mul.Tensor %938, %940 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %942 = torch.aten.cos %941 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %int5_797 = torch.constant.int 5
+ %943 = torch.prims.convert_element_type %942, %int5_797 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
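
Reviewer note: the `+` side swaps the old interleaved frequency computation (`arange(128)` floor-divided by 2, visible on the `-` side further up) for what reads as the Llama-3 rope-scaling recipe: inverse frequencies `500000^(-2i/128)` are turned into wavelengths (2*pi/freq), bands with wavelength above 8192 are divided by a factor of 8, bands below 2048 are left unscaled, and the band in between is blended with the smoothing weight `(8192/wavelen - 1)/3`; the result is concatenated with itself into a 128-wide table whose cos (`%943`) and sin (`%944`, below) are taken in f16. A sketch of that computation as I read it, constants copied from the IR and names mine:

    import math
    import torch

    def llama3_scaled_inv_freq(head_dim=128, theta=5.0e5, factor=8.0,
                               low_wavelen=8192.0, high_wavelen=2048.0):
        inv_freq = 1.0 / theta ** (torch.arange(0, head_dim, 2).float() / head_dim)
        wavelen = 2 * math.pi / inv_freq
        scaled = torch.where(wavelen > low_wavelen, inv_freq / factor, inv_freq)  # %918
        smooth = (low_wavelen / wavelen - 1.0) / 3.0                              # %922
        blended = (1 - smooth) * scaled / factor + smooth * scaled                # %927
        mid = (wavelen >= high_wavelen) & (wavelen <= low_wavelen)                # %932
        return torch.where(mid, blended, scaled)                                  # %933

    inv_freq = llama3_scaled_inv_freq()
    table = torch.cat([inv_freq, inv_freq])                  # %935: [128]
    angles = torch.arange(131072).float()[:, None] * table   # %941: [131072, 128]
    cos, sin = angles.cos().half(), angles.sin().half()      # %943, %944

+ %944 = 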
torch.aten.sin %941 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_798 = torch.constant.int 5 + %945 = torch.prims.convert_element_type %944, %int5_798 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> %int0_799 = torch.constant.int 0 - %int1_800 = torch.constant.int 1 - %940 = torch.aten.slice.Tensor %937, %int0_798, %int0_799, %939, %int1_800 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %940, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %int0_800 = torch.constant.int 0 %int1_801 = torch.constant.int 1 - %int0_802 = torch.constant.int 0 - %int9223372036854775807_803 = torch.constant.int 9223372036854775807 - %int1_804 = torch.constant.int 1 - %941 = torch.aten.slice.Tensor %940, %int1_801, %int0_802, %int9223372036854775807_803, %int1_804 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %941, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %946 = torch.aten.slice.Tensor %943, %int0_799, %int0_800, %298, %int1_801 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %946, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_802 = torch.constant.int 1 + %int0_803 = torch.constant.int 0 + %int9223372036854775807_804 = torch.constant.int 9223372036854775807 %int1_805 = torch.constant.int 1 + %947 = torch.aten.slice.Tensor %946, %int1_802, %int0_803, %int9223372036854775807_804, %int1_805 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %947, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int0_806 = torch.constant.int 0 - %int9223372036854775807_807 = torch.constant.int 9223372036854775807 + %int0_807 = torch.constant.int 0 %int1_808 = torch.constant.int 1 - %942 = torch.aten.slice.Tensor %941, %int1_805, %int0_806, %int9223372036854775807_807, %int1_808 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %942, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_809 = torch.constant.int 0 - %943 = torch.aten.unsqueeze %942, %int0_809 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %943, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_810 = torch.constant.int 1 - %int0_811 = torch.constant.int 0 - %int9223372036854775807_812 = torch.constant.int 9223372036854775807 - %int1_813 = torch.constant.int 1 - %944 = torch.aten.slice.Tensor %943, %int1_810, %int0_811, %int9223372036854775807_812, %int1_813 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %944, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_814 = torch.constant.int 2 + %948 = torch.aten.slice.Tensor %945, %int0_806, %int0_807, %298, %int1_808 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %948, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + 
%int1_809 = torch.constant.int 1 + %int0_810 = torch.constant.int 0 + %int9223372036854775807_811 = torch.constant.int 9223372036854775807 + %int1_812 = torch.constant.int 1 + %949 = torch.aten.slice.Tensor %948, %int1_809, %int0_810, %int9223372036854775807_811, %int1_812 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %949, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_813 = torch.constant.int 0 + %950 = torch.aten.unsqueeze %947, %int0_813 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %950, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_814 = torch.constant.int 1 %int0_815 = torch.constant.int 0 %int9223372036854775807_816 = torch.constant.int 9223372036854775807 %int1_817 = torch.constant.int 1 - %945 = torch.aten.slice.Tensor %944, %int2_814, %int0_815, %int9223372036854775807_816, %int1_817 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %945, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_818 = torch.constant.int 4 - %int1_819 = torch.constant.int 1 - %int1_820 = torch.constant.int 1 - %946 = torch.prim.ListConstruct %int4_818, %int1_819, %int1_820 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %947 = torch.aten.repeat %945, %946 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %947, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_821 = torch.constant.int 6 - %948 = torch.prims.convert_element_type %921, %int6_821 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %948, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %949 = torch_c.to_builtin_tensor %948 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %950 = torch_c.to_builtin_tensor %947 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %951 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%949, %950) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %952 = torch_c.from_builtin_tensor %951 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %952, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_822 = torch.constant.int 5 - %953 = torch.prims.convert_element_type %952, %int5_822 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %953, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_823 = torch.constant.int 131072 - %none_824 = torch.constant.none - %none_825 = torch.constant.none - %cpu_826 = torch.constant.device "cpu" - %false_827 = torch.constant.bool false - %954 = torch.aten.arange %int131072_823, %none_824, %none_825, %cpu_826, %false_827 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_828 = torch.constant.int 0 - %int128_829 = torch.constant.int 128 - %none_830 = torch.constant.none - %none_831 = torch.constant.none - %cpu_832 = torch.constant.device "cpu" - %false_833 = torch.constant.bool false - %955 = torch.aten.arange.start %int0_828, %int128_829, %none_830, 
%none_831, %cpu_832, %false_833 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64>
- %int2_834 = torch.constant.int 2
- %956 = torch.aten.floor_divide.Scalar %955, %int2_834 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64>
- %int6_835 = torch.constant.int 6
- %957 = torch.prims.convert_element_type %956, %int6_835 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32>
- %int128_836 = torch.constant.int 128
- %958 = torch.aten.div.Scalar %957, %int128_836 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
- %float2.000000e00_837 = torch.constant.float 2.000000e+00
- %959 = torch.aten.mul.Scalar %958, %float2.000000e00_837 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
- %float5.000000e05_838 = torch.constant.float 5.000000e+05
- %960 = torch.aten.pow.Scalar %float5.000000e05_838, %959 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
- %961 = torch.aten.reciprocal %960 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
- %float1.000000e00_839 = torch.constant.float 1.000000e+00
- %962 = torch.aten.mul.Scalar %961, %float1.000000e00_839 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
+ %951 = torch.aten.slice.Tensor %950, %int1_814, %int0_815, %int9223372036854775807_816, %int1_817 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %951, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int2_818 = torch.constant.int 2
+ %952 = torch.aten.unsqueeze %951, %int2_818 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %952, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_819 = torch.constant.int 3
+ %int0_820 = torch.constant.int 0
+ %int9223372036854775807_821 = torch.constant.int 9223372036854775807
+ %int1_822 = torch.constant.int 1
+ %953 = torch.aten.slice.Tensor %952, %int3_819, %int0_820, %int9223372036854775807_821, %int1_822 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %953, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int4_823 = torch.constant.int 4
+ %int1_824 = torch.constant.int 1
+ %int1_825 = torch.constant.int 1
+ %int1_826 = torch.constant.int 1
+ %954 = torch.prim.ListConstruct %int4_823, %int1_824, %int1_825, %int1_826 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %955 = torch.aten.repeat %953, %954 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %955, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+ %int0_827 = torch.constant.int 0
+ %956 = torch.aten.unsqueeze %949, %int0_827 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %956, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int1_828 = torch.constant.int 1
+ %int0_829 = torch.constant.int 0
+ %int9223372036854775807_830 = torch.constant.int 9223372036854775807
+ %int1_831 = torch.constant.int 1
+ %957 = torch.aten.slice.Tensor %956, %int1_828, %int0_829, %int9223372036854775807_830, %int1_831 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %957, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int2_832 = torch.constant.int 2
+ %958 = torch.aten.unsqueeze %957, %int2_832 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %958, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_833 = torch.constant.int 3
+ %int0_834 = torch.constant.int 0
+ %int9223372036854775807_835 = torch.constant.int 9223372036854775807
+ %int1_836 = torch.constant.int 1
+ %959 = torch.aten.slice.Tensor %958, %int3_833, %int0_834, %int9223372036854775807_835, %int1_836 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %959, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int4_837 = torch.constant.int 4
+ %int1_838 = torch.constant.int 1
+ %int1_839 = torch.constant.int 1
%int1_840 = torch.constant.int 1
- %963 = torch.aten.unsqueeze %954, %int1_840 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
- %int0_841 = torch.constant.int 0
- %964 = torch.aten.unsqueeze %962, %int0_841 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
- %965 = torch.aten.mul.Tensor %963, %964 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
- %int1_842 = torch.constant.int 1
- %966 = torch.aten.size.int %912, %int1_842 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int
- %int0_843 = torch.constant.int 0
- %967 = torch.aten.add.int %int0_843, %966 : !torch.int, !torch.int -> !torch.int
- %int0_844 = torch.constant.int 0
- %int0_845 = torch.constant.int 0
- %int1_846 = torch.constant.int 1
- %968 = torch.aten.slice.Tensor %965, %int0_844, %int0_845, %967, %int1_846 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %968, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
- %int1_847 = torch.constant.int 1
- %int0_848 = torch.constant.int 0
- %int9223372036854775807_849 = torch.constant.int 9223372036854775807
+ %960 = torch.prim.ListConstruct %int4_837, %int1_838, %int1_839, %int1_840 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %961 = torch.aten.repeat %959, %960 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %961, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+ %962 = torch.aten.mul.Tensor %902, %955 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %962, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int3_841 = torch.constant.int 3
+ %int0_842 = torch.constant.int 0
+ %int64_843 = torch.constant.int 64
+ %int1_844 = torch.constant.int 1
+ %963 = torch.aten.slice.Tensor %902, %int3_841, %int0_842, %int64_843, %int1_844 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16>
+ torch.bind_symbolic_shape %963, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16>
+ %int3_845 = torch.constant.int 3
+ %int64_846 = torch.constant.int 64
+ %int9223372036854775807_847 = torch.constant.int 9223372036854775807
+ %int1_848 = torch.constant.int 1
+ %964 = torch.aten.slice.Tensor %902, %int3_845, %int64_846, %int9223372036854775807_847, %int1_848 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16>
+ torch.bind_symbolic_shape %964, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16>
+ %965 = torch.aten.neg %964 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16>
+ torch.bind_symbolic_shape %965, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16>
+ %966 = torch.prim.ListConstruct %965, %963 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list<vtensor>
+ %int-1_849 = torch.constant.int -1
+ %967 = torch.aten.cat %966, %int-1_849 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %967, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %968 = torch.aten.mul.Tensor %967, %961 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %968, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
%int1_850 = torch.constant.int 1
- %969 = torch.aten.slice.Tensor %968, %int1_847, %int0_848, %int9223372036854775807_849, %int1_850 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %969, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
- %int1_851 = torch.constant.int 1
- %int0_852 = torch.constant.int 0
- %int9223372036854775807_853 = torch.constant.int 9223372036854775807
- %int1_854 = torch.constant.int 1
- %970 = torch.aten.slice.Tensor %969, %int1_851, %int0_852, %int9223372036854775807_853, %int1_854 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %970, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
- %int0_855 = torch.constant.int 0
- %971 = torch.aten.unsqueeze %970, %int0_855 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %971, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
- %int1_856 = torch.constant.int 1
- %int0_857 = torch.constant.int 0
- %int9223372036854775807_858 = torch.constant.int 9223372036854775807
- %int1_859 = torch.constant.int 1
- %972 = torch.aten.slice.Tensor %971, %int1_856, %int0_857, %int9223372036854775807_858, %int1_859 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %972, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
- %int2_860 = torch.constant.int 2
- %int0_861 = torch.constant.int 0
- %int9223372036854775807_862 = torch.constant.int 9223372036854775807
- %int1_863 = torch.constant.int 1
- %973 = torch.aten.slice.Tensor %972, %int2_860, %int0_861, %int9223372036854775807_862, %int1_863 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %973, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
- %int4_864 = torch.constant.int 4
- %int1_865 = torch.constant.int 1
- %int1_866 = torch.constant.int 1
- %974 = torch.prim.ListConstruct %int4_864, %int1_865, %int1_866 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %975 = torch.aten.repeat %973, %974 : !torch.vtensor<[1,?,128],f32>, !torch.list<int> -> !torch.vtensor<[4,?,128],f32>
- torch.bind_symbolic_shape %975, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32>
- %int6_867 = torch.constant.int 6
- %976 = torch.prims.convert_element_type %923, %int6_867 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32>
- torch.bind_symbolic_shape %976, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32>
- %977 = torch_c.to_builtin_tensor %976 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32>
- %978 = torch_c.to_builtin_tensor %975 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32>
- %979 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%977, %978) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32>
- %980 = torch_c.from_builtin_tensor %979 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32>
- torch.bind_symbolic_shape %980, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32>
- %int5_868 = torch.constant.int 5
- %981 = torch.prims.convert_element_type %980, %int5_868 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %981, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int64_869 = torch.constant.int 64
- %982 = torch.aten.mul.Scalar %arg2, %int64_869 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %982, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int6_870 = torch.constant.int 6
+ %969 = torch.aten.add.Tensor %962, %968, %int1_850 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %969, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int131072_851 = torch.constant.int 131072
+ %none_852 = torch.constant.none
+ %none_853 = torch.constant.none
+ %cpu_854 = torch.constant.device "cpu"
+ %false_855 = torch.constant.bool false
+ %970 = torch.aten.arange %int131072_851, %none_852, %none_853, %cpu_854, %false_855 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
+ %int0_856 = torch.constant.int 0
+ %int128_857 = torch.constant.int 128
+ %int2_858 = torch.constant.int 2
+ %int4_859 = torch.constant.int 4
+ %none_860 = torch.constant.none
+ %cpu_861 = torch.constant.device "cpu"
+ %false_862 = torch.constant.bool false
+ %971 = torch.aten.arange.start_step %int0_856, %int128_857, %int2_858, %int4_859, %none_860, %cpu_861, %false_862 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64>
+ %int6_863 = torch.constant.int 6
+ %972 = torch.prims.convert_element_type %971, %int6_863 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32>
+ %int128_864 = torch.constant.int 128
+ %973 = torch.aten.div.Scalar %972, %int128_864 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float5.000000e05_865 = torch.constant.float 5.000000e+05
+ %974 = torch.aten.pow.Scalar %float5.000000e05_865, %973 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %975 = torch.aten.reciprocal %974 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float1.000000e00_866 = torch.constant.float 1.000000e+00
+ %976 = torch.aten.mul.Scalar %975, %float1.000000e00_866 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %977 = torch.aten.reciprocal %976 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float6.283190e00_867 = torch.constant.float 6.2831853071795862
+ %978 = torch.aten.mul.Scalar %977, %float6.283190e00_867 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %float8.192000e03_868 = torch.constant.float 8.192000e+03
+ %979 = torch.aten.gt.Scalar %978, %float8.192000e03_868 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %int8_869 = torch.constant.int 8
+ %980 = torch.aten.div.Scalar %976, %int8_869 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %981 = torch.aten.where.self %979, %980, %976 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %982 = torch.aten.reciprocal %978 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8192_870 = torch.constant.int 8192
+ %983 = torch.aten.mul.Scalar %982, %int8192_870 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
%int1_871 = torch.constant.int 1
- %983 = torch.aten.add.Scalar %982, %int6_870, %int1_871 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %983, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int4_872 = torch.constant.int 4
- %int32_873 = torch.constant.int 32
- %int8_874 = torch.constant.int 8
- %int128_875 = torch.constant.int 128
- %984 = torch.prim.ListConstruct %int4_872, %398, %int32_873, %int8_874, %int128_875 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %985 = torch.aten.view %981, %984 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %985, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int4_876 = torch.constant.int 4
- %986 = torch.aten.mul.int %int4_876, %398 : !torch.int, !torch.int -> !torch.int
- %int32_877 = torch.constant.int 32
- %int8_878 = torch.constant.int 8
- %int128_879 = torch.constant.int 128
- %987 = torch.prim.ListConstruct %986, %int32_877, %int8_878, %int128_879 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %988 = torch.aten.view %985, %987 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %988, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int4_880 = torch.constant.int 4
- %989 = torch.aten.mul.int %int4_880, %398 : !torch.int, !torch.int -> !torch.int
- %990 = torch.prim.ListConstruct %989 : (!torch.int) -> !torch.list<int>
- %991 = torch.aten.view %983, %990 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
- torch.bind_symbolic_shape %991, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
- %int32_881 = torch.constant.int 32
- %int2_882 = torch.constant.int 2
- %int32_883 = torch.constant.int 32
- %int8_884 = torch.constant.int 8
- %int128_885 = torch.constant.int 128
- %992 = torch.prim.ListConstruct %389, %int32_881, %int2_882, %int32_883, %int8_884, %int128_885 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %993 = torch.aten.view %825, %992 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %993, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int32_886 = torch.constant.int 32
- %994 = torch.aten.mul.int %389, %int32_886 : !torch.int, !torch.int -> !torch.int
- %int2_887 = torch.constant.int 2
- %995 = torch.aten.mul.int %994, %int2_887 : !torch.int, !torch.int -> !torch.int
- %int32_888 = torch.constant.int 32
- %int8_889 = torch.constant.int 8
- %int128_890 = torch.constant.int 128
- %996 = torch.prim.ListConstruct %995, %int32_888, %int8_889, %int128_890 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %997 = torch.aten.view %993, %996 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %997, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %998 = torch.prim.ListConstruct %991 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
- %false_891 = torch.constant.bool false
- %999 = torch.aten.index_put %997, %998, %988, %false_891 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %999, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int32_892 = torch.constant.int 32
- %int2_893 = torch.constant.int 2
- %int32_894 = torch.constant.int 32
- %int8_895 = torch.constant.int 8
- %int128_896 = torch.constant.int 128
- %1000 = torch.prim.ListConstruct %389, %int32_892, %int2_893, %int32_894, %int8_895, %int128_896 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1001 = torch.aten.view %999, %1000 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %1001, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int2097152_897 = torch.constant.int 2097152
- %1002 = torch.prim.ListConstruct %389, %int2097152_897 : (!torch.int, !torch.int) -> !torch.list<int>
- %1003 = torch.aten.view %1001, %1002 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
- torch.bind_symbolic_shape %1003, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
- %int32_898 = torch.constant.int 32
- %int2_899 = torch.constant.int 2
- %int32_900 = torch.constant.int 32
- %int8_901 = torch.constant.int 8
- %int128_902 = torch.constant.int 128
- %1004 = torch.prim.ListConstruct %389, %int32_898, %int2_899, %int32_900, %int8_901, %int128_902 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1005 = torch.aten.view %1003, %1004 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %1005, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int32_903 = torch.constant.int 32
- %int8_904 = torch.constant.int 8
- %int128_905 = torch.constant.int 128
- %1006 = torch.prim.ListConstruct %995, %int32_903, %int8_904, %int128_905 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1007 = torch.aten.view %1005, %1006 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %1007, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int4_906 = torch.constant.int 4
- %int32_907 = torch.constant.int 32
- %int8_908 = torch.constant.int 8
- %int128_909 = torch.constant.int 128
- %1008 = torch.prim.ListConstruct %int4_906, %398, %int32_907, %int8_908, %int128_909 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1009 = torch.aten.view %925, %1008 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %1009, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int4_910 = torch.constant.int 4
- %1010 = torch.aten.mul.int %int4_910, %398 : !torch.int, !torch.int -> !torch.int
- %int32_911 = torch.constant.int 32
- %int8_912 = torch.constant.int 8
- %int128_913 = torch.constant.int 128
- %1011 = torch.prim.ListConstruct %1010, %int32_911, %int8_912, %int128_913 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1012 = torch.aten.view %1009, %1011 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %1012, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
+ %int1_872 = torch.constant.int 1
+ %984 = torch.aten.sub.Scalar %983, %int1_871, %int1_872 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %int3_873 = torch.constant.int 3
+ %985 = torch.aten.div.Scalar %984, %int3_873 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %int1_874 = torch.constant.int 1
+ %int1_875 = torch.constant.int 1
+ %986 = torch.aten.rsub.Scalar %985, %int1_874, %int1_875 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %987 = torch.aten.mul.Tensor %986, %981 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8_876 = torch.constant.int 8
+ %988 = torch.aten.div.Scalar %987, %int8_876 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %989 = torch.aten.mul.Tensor %985, %981 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int1_877 = torch.constant.int 1
+ %990 = torch.aten.add.Tensor %988, %989, %int1_877 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float2.048000e03_878 = torch.constant.float 2.048000e+03
+ %991 = torch.aten.lt.Scalar %978, %float2.048000e03_878 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %992 = torch.aten.bitwise_not %991 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %float8.192000e03_879 = torch.constant.float 8.192000e+03
+ %993 = torch.aten.gt.Scalar %978, %float8.192000e03_879 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %994 = torch.aten.bitwise_not %993 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %995 = torch.aten.mul.Tensor %992, %994 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %996 = torch.aten.where.self %995, %990, %981 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %997 = torch.prim.ListConstruct %996, %996 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list<vtensor>
+ %int-1_880 = torch.constant.int -1
+ %998 = torch.aten.cat %997, %int-1_880 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[128],f32>
+ %int6_881 = torch.constant.int 6
+ %999 = torch.prims.convert_element_type %998, %int6_881 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
+ %int1_882 = torch.constant.int 1
+ %1000 = torch.aten.unsqueeze %970, %int1_882 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
+ %int6_883 = torch.constant.int 6
+ %1001 = torch.prims.convert_element_type %1000, %int6_883 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32>
+ %int0_884 = torch.constant.int 0
+ %1002 = torch.aten.unsqueeze %999, %int0_884 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+ %int6_885 = torch.constant.int 6
+ %1003 = torch.prims.convert_element_type %1002, %int6_885 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+ %1004 = torch.aten.mul.Tensor %1001, %1003 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %1005 = torch.aten.cos %1004 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %int5_886 = torch.constant.int 5
+ %1006 = torch.prims.convert_element_type %1005, %int5_886 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
+ %1007 = torch.aten.sin %1004 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %int5_887 = torch.constant.int 5
+ %1008 = torch.prims.convert_element_type %1007, %int5_887 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
+ %int0_888 = torch.constant.int 0
+ %int0_889 = torch.constant.int 0
+ %int1_890 = torch.constant.int 1
+ %1009 = torch.aten.slice.Tensor %1006, %int0_888, %int0_889, %298, %int1_890 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %1009, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int1_891 = torch.constant.int 1
+ %int0_892 = torch.constant.int 0
+ %int9223372036854775807_893 = torch.constant.int 9223372036854775807
+ %int1_894 = torch.constant.int 1
+ %1010 = torch.aten.slice.Tensor %1009, %int1_891, %int0_892, %int9223372036854775807_893, %int1_894 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %1010, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int0_895 = torch.constant.int 0
+ %int0_896 = torch.constant.int 0
+ %int1_897 = torch.constant.int 1
+ %1011 = torch.aten.slice.Tensor %1008, %int0_895, %int0_896, %298, %int1_897 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %1011, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int1_898 = torch.constant.int 1
+ %int0_899 = torch.constant.int 0
+ %int9223372036854775807_900 = torch.constant.int 9223372036854775807
+ %int1_901 = torch.constant.int 1
+ %1012 = torch.aten.slice.Tensor %1011, %int1_898, %int0_899, %int9223372036854775807_900, %int1_901 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %1012, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int0_902 = torch.constant.int 0
+ %1013 = torch.aten.unsqueeze %1010, %int0_902 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %1013, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int1_903 = torch.constant.int 1
+ %int0_904 = torch.constant.int 0
+ %int9223372036854775807_905 = torch.constant.int 9223372036854775807
+ %int1_906 = torch.constant.int 1
+ %1014 = torch.aten.slice.Tensor %1013, %int1_903, %int0_904, %int9223372036854775807_905, %int1_906 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %1014, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int2_907 = torch.constant.int 2
+ %1015 = torch.aten.unsqueeze %1014, %int2_907 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %1015, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_908 = torch.constant.int 3
+ %int0_909 = torch.constant.int 0
+ %int9223372036854775807_910 = torch.constant.int 9223372036854775807
+ %int1_911 = torch.constant.int 1
+ %1016 = torch.aten.slice.Tensor %1015, %int3_908, %int0_909, %int9223372036854775807_910, %int1_911 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %1016, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int4_912 = torch.constant.int 4
+ %int1_913 = torch.constant.int 1
%int1_914 = torch.constant.int 1
%int1_915 = torch.constant.int 1
- %1013 = torch.aten.add.Scalar %983, %int1_914, %int1_915 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %1013, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int4_916 = torch.constant.int 4
- %1014 = torch.aten.mul.int %int4_916, %398 : !torch.int, !torch.int -> !torch.int
- %1015 = torch.prim.ListConstruct %1014 : (!torch.int) -> !torch.list<int>
- %1016 = torch.aten.view %1013, %1015 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
- torch.bind_symbolic_shape %1016, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
- %1017 = torch.prim.ListConstruct %1016 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
- %false_917 = torch.constant.bool false
- %1018 = torch.aten.index_put %1007, %1017, %1012, %false_917 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %1018, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int32_918 = torch.constant.int 32
- %int2_919 = torch.constant.int 2
- %int32_920 = torch.constant.int 32
- %int8_921 = torch.constant.int 8
- %int128_922 = torch.constant.int 128
- %1019 = torch.prim.ListConstruct %389, %int32_918, %int2_919, %int32_920, %int8_921, %int128_922 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1020 = torch.aten.view %1018, %1019 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %1020, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int2097152_923 = torch.constant.int 2097152
- %1021 = torch.prim.ListConstruct %389, %int2097152_923 : (!torch.int, !torch.int) -> !torch.list<int>
- %1022 = torch.aten.view %1020, %1021 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
- torch.bind_symbolic_shape %1022, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
- %int-2_924 = torch.constant.int -2
- %1023 = torch.aten.unsqueeze %981, %int-2_924 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
- torch.bind_symbolic_shape %1023, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
- %int4_925 = torch.constant.int 4
- %int8_926 = torch.constant.int 8
- %int4_927 = torch.constant.int 4
- %int128_928 = torch.constant.int 128
- %1024 = torch.prim.ListConstruct %int4_925, %966, %int8_926, %int4_927, %int128_928 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %false_929 = torch.constant.bool false
- %1025 = torch.aten.expand %1023, %1024, %false_929 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %1025, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int0_930 = torch.constant.int 0
- %1026 = torch.aten.clone %1025, %int0_930 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %1026, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int4_931 = torch.constant.int 4
- %int32_932 = torch.constant.int 32
- %int128_933 = torch.constant.int 128
- %1027 = torch.prim.ListConstruct %int4_931, %966, %int32_932, %int128_933 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1028 = torch.aten._unsafe_view %1026, %1027 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %1028, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int-2_934 = torch.constant.int -2
- %1029 = torch.aten.unsqueeze %925, %int-2_934 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
- torch.bind_symbolic_shape %1029, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
- %int1_935 = torch.constant.int 1
- %1030 = torch.aten.size.int %919, %int1_935 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int
- %int4_936 = torch.constant.int 4
- %int8_937 = torch.constant.int 8
- %int4_938 = torch.constant.int 4
- %int128_939 = torch.constant.int 128
- %1031 = torch.prim.ListConstruct %int4_936, %1030, %int8_937, %int4_938, %int128_939 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %false_940 = torch.constant.bool false
- %1032 = torch.aten.expand %1029, %1031, %false_940 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %1032, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int0_941 = torch.constant.int 0
- %1033 = torch.aten.clone %1032, %int0_941 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %1033, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int4_942 = torch.constant.int 4
- %int32_943 = torch.constant.int 32
- %int128_944 = torch.constant.int 128
- %1034 = torch.prim.ListConstruct %int4_942, %1030, %int32_943, %int128_944 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1035 = torch.aten._unsafe_view %1033, %1034 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %1035, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %1017 = torch.prim.ListConstruct %int4_912, %int1_913, %int1_914, %int1_915 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1018 = torch.aten.repeat %1016, %1017 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %1018, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+ %int0_916 = torch.constant.int 0
+ %1019 = torch.aten.unsqueeze %1012, %int0_916 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %1019, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int1_917 = torch.constant.int 1
+ %int0_918 = torch.constant.int 0
+ %int9223372036854775807_919 = torch.constant.int 9223372036854775807
+ %int1_920 = torch.constant.int 1
+ %1020 = torch.aten.slice.Tensor %1019, %int1_917, %int0_918, %int9223372036854775807_919, %int1_920 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %1020, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int2_921 = torch.constant.int 2
+ %1021 = torch.aten.unsqueeze %1020, %int2_921 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %1021, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_922 = torch.constant.int 3
+ %int0_923 = torch.constant.int 0
+ %int9223372036854775807_924 = torch.constant.int 9223372036854775807
+ %int1_925 = torch.constant.int 1
+ %1022 = torch.aten.slice.Tensor %1021, %int3_922, %int0_923, %int9223372036854775807_924, %int1_925 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %1022, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int4_926 = torch.constant.int 4
+ %int1_927 = torch.constant.int 1
+ %int1_928 = torch.constant.int 1
+ %int1_929 = torch.constant.int 1
+ %1023 = torch.prim.ListConstruct %int4_926, %int1_927, %int1_928, %int1_929 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1024 = torch.aten.repeat %1022, %1023 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %1024, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+ %1025 = torch.aten.mul.Tensor %904, %1018 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %1025, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int3_930 = torch.constant.int 3
+ %int0_931 = torch.constant.int 0
+ %int64_932 = torch.constant.int 64
+ %int1_933 = torch.constant.int 1
+ %1026 = torch.aten.slice.Tensor %904, %int3_930, %int0_931, %int64_932, %int1_933 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16>
+ torch.bind_symbolic_shape %1026, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16>
+ %int3_934 = torch.constant.int 3
+ %int64_935 = torch.constant.int 64
+ %int9223372036854775807_936 = torch.constant.int 9223372036854775807
+ %int1_937 = torch.constant.int 1
+ %1027 = torch.aten.slice.Tensor %904, %int3_934, %int64_935, %int9223372036854775807_936, %int1_937 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16>
+ torch.bind_symbolic_shape %1027, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16>
+ %1028 = torch.aten.neg %1027 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16>
+ torch.bind_symbolic_shape %1028, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16>
+ %1029 = torch.prim.ListConstruct %1028, %1026 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list<vtensor>
+ %int-1_938 = torch.constant.int -1
+ %1030 = torch.aten.cat %1029, %int-1_938 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %1030, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %1031 = torch.aten.mul.Tensor %1030, %1024 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %1031, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int1_939 = torch.constant.int 1
+ %1032 = torch.aten.add.Tensor %1025, %1031, %int1_939 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %1032, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int32_940 = torch.constant.int 32
+ %1033 = torch.aten.mul.Scalar %arg2, %int32_940 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1033, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int2_941 = torch.constant.int 2
+ %int1_942 = torch.constant.int 1
+ %1034 = torch.aten.add.Scalar %1033, %int2_941, %int1_942 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1034, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int2_943 = torch.constant.int 2
+ %1035 = torch.aten.mul.Scalar %1034, %int2_943 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1035, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int0_944 = torch.constant.int 0
%int1_945 = torch.constant.int 1
- %int2_946 = torch.constant.int 2
- %1036 = torch.aten.transpose.int %953, %int1_945, %int2_946 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %1036, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_947 = torch.constant.int 1
- %int2_948 = torch.constant.int 2
- %1037 = torch.aten.transpose.int %1028, %int1_947, %int2_948 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %1037, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_949 = torch.constant.int 1
- %int2_950 = torch.constant.int 2
- %1038 = torch.aten.transpose.int %1035, %int1_949, %int2_950 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %1038, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %float0.000000e00_951 = torch.constant.float 0.000000e+00
- %true_952 = torch.constant.bool true
- %none_953 = torch.constant.none
- %none_954 = torch.constant.none
- %1039:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1036, %1037, %1038, %float0.000000e00_951, %true_952, %none_953, %none_954) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>)
- torch.bind_symbolic_shape %1039#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_955 = torch.constant.int 1
- %int2_956 = torch.constant.int 2
- %1040 = torch.aten.transpose.int %1039#0, %int1_955, %int2_956 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %1040, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int4_957 = torch.constant.int 4
- %int4096_958 = torch.constant.int 4096
- %1041 = torch.prim.ListConstruct %int4_957, %938, %int4096_958 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1042 = torch.aten.view %1040, %1041 : !torch.vtensor<[4,?,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1042, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_959 = torch.constant.int -2
- %int-1_960 = torch.constant.int -1
- %1043 = torch.aten.transpose.int %32, %int-2_959, %int-1_960 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_961 = torch.constant.int 4
- %1044 = torch.aten.mul.int %int4_961, %938 : !torch.int, !torch.int -> !torch.int
- %int4096_962 = torch.constant.int 4096
- %1045 = torch.prim.ListConstruct %1044, %int4096_962 : (!torch.int, !torch.int) -> !torch.list<int>
- %1046 = torch.aten.view %1042, %1045 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1046, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %1047 = torch.aten.mm %1046, %1043 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1047, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_963 = torch.constant.int 4
- %int4096_964 = torch.constant.int 4096
- %1048 = torch.prim.ListConstruct %int4_963, %938, %int4096_964 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1049 = torch.aten.view %1047, %1048 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1049, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int1_965 = torch.constant.int 1
- %1050 = torch.aten.add.Tensor %888, %1049, %int1_965 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1050, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int6_966 = torch.constant.int 6
- %1051 = torch.prims.convert_element_type %1050, %int6_966 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1051, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int2_967 = torch.constant.int 2
- %1052 = torch.aten.pow.Tensor_Scalar %1051, %int2_967 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1052, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int-1_968 = torch.constant.int -1
- %1053 = torch.prim.ListConstruct %int-1_968 : (!torch.int) -> !torch.list<int>
- %true_969 = torch.constant.bool true
- %none_970 = torch.constant.none
- %1054 = torch.aten.mean.dim %1052, %1053, %true_969, %none_970 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1054, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %float9.999990e-06_971 = torch.constant.float 9.9999997473787516E-6
- %int1_972 = torch.constant.int 1
- %1055 = torch.aten.add.Scalar %1054, %float9.999990e-06_971, %int1_972 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1055, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %1056 = torch.aten.rsqrt %1055 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1056, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %1057 = torch.aten.mul.Tensor %1051, %1056 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1057, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_973 = torch.constant.int 5
- %1058 = torch.prims.convert_element_type %1057, %int5_973 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1058, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %1059 = torch.aten.mul.Tensor %33, %1058 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1059, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_974 = torch.constant.int 5
- %1060 = torch.prims.convert_element_type %1059, %int5_974 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1060, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_975 = torch.constant.int -2
- %int-1_976 = torch.constant.int -1
- %1061 = torch.aten.transpose.int %34, %int-2_975, %int-1_976 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_977 = torch.constant.int 4
- %1062 = torch.aten.mul.int %int4_977, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_978 = torch.constant.int 4096
- %1063 = torch.prim.ListConstruct %1062, %int4096_978 : (!torch.int, !torch.int) -> !torch.list<int>
- %1064 = torch.aten.view %1060, %1063 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1064, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %1065 = torch.aten.mm %1064, %1061 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %1065, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %int4_979 = torch.constant.int 4
- %int14336_980 = torch.constant.int 14336
- %1066 = torch.prim.ListConstruct %int4_979, %306, %int14336_980 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1067 = torch.aten.view %1065, %1066 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %1067, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %1068 = torch.aten.silu %1067 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %1068, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %int-2_981 = torch.constant.int -2
- %int-1_982 = torch.constant.int -1
- %1069 = torch.aten.transpose.int %35, %int-2_981, %int-1_982 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_983 = torch.constant.int 4
- %1070 = torch.aten.mul.int %int4_983, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_984 = torch.constant.int 4096
- %1071 = torch.prim.ListConstruct %1070, %int4096_984 : (!torch.int, !torch.int) -> !torch.list<int>
- %1072 = torch.aten.view %1060, %1071 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1072, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %1073 = torch.aten.mm %1072, %1069 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %1073, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
+ %1036 = torch.aten.add.Scalar %1035, %int0_944, %int1_945 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1036, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %1037 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list<int>
+ %1038 = torch.aten.view %1036, %1037 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
+ torch.bind_symbolic_shape %1038, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
+ %int4_946 = torch.constant.int 4
+ %int32_947 = torch.constant.int 32
+ %int8_948 = torch.constant.int 8
+ %int128_949 = torch.constant.int 128
+ %1039 = torch.prim.ListConstruct %int4_946, %296, %int32_947, %int8_948, %int128_949 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1040 = torch.aten.view %1032, %1039 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %1040, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int32_950 = torch.constant.int 32
+ %int8_951 = torch.constant.int 8
+ %int128_952 = torch.constant.int 128
+ %1041 = torch.prim.ListConstruct %504, %int32_950, %int8_951, %int128_952 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1042 = torch.aten.view %1040, %1041 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
+ torch.bind_symbolic_shape %1042, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
+ %int1_953 = torch.constant.int 1
+ %int2_954 = torch.constant.int 2
+ %1043 = torch.aten.transpose.int %1042, %int1_953, %int2_954 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1043, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int5_955 = torch.constant.int 5
+ %1044 = torch.prims.convert_element_type %1043, %int5_955 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1044, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int32_956 = torch.constant.int 32
+ %int2_957 = torch.constant.int 2
+ %int8_958 = torch.constant.int 8
+ %int32_959 = torch.constant.int 32
+ %int128_960 = torch.constant.int 128
+ %1045 = torch.prim.ListConstruct %297, %int32_956, %int2_957, %int8_958, %int32_959, %int128_960 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1046 = torch.aten.view %808, %1045 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1046, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int8_961 = torch.constant.int 8
+ %int32_962 = torch.constant.int 32
+ %int128_963 = torch.constant.int 128
+ %1047 = torch.prim.ListConstruct %497, %int8_961, %int32_962, %int128_963 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1048 = torch.aten.view %1046, %1047 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1048, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %1049 = torch.prim.ListConstruct %1038 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
+ %false_964 = torch.constant.bool false
+ %1050 = torch.aten.index_put %1048, %1049, %1044, %false_964 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1050, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int32_965 = torch.constant.int 32
+ %int2_966 = torch.constant.int 2
+ %int8_967 = torch.constant.int 8
+ %int32_968 = torch.constant.int 32
+ %int128_969 = torch.constant.int 128
+ %1051 = torch.prim.ListConstruct %297, %int32_965, %int2_966, %int8_967, %int32_968, %int128_969 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1052 = torch.aten.view %1050, %1051 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1052, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_970 = torch.constant.int 2097152
+ %1053 = torch.prim.ListConstruct %297, %int2097152_970 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1054 = torch.aten.view %1052, %1053 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %1054, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %int32_971 = torch.constant.int 32
+ %int2_972 = torch.constant.int 2
+ %int8_973 = torch.constant.int 8
+ %int32_974 = torch.constant.int 32
+ %int128_975 = torch.constant.int 128
+ %1055 = torch.prim.ListConstruct %297, %int32_971, %int2_972, %int8_973, %int32_974, %int128_975 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1056 = torch.aten.view %1054, %1055 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1056, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int8_976 = torch.constant.int 8
+ %int32_977 = torch.constant.int 32
+ %int128_978 = torch.constant.int 128
+ %1057 = torch.prim.ListConstruct %497, %int8_976, %int32_977, %int128_978 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1058 = torch.aten.view %1056, %1057 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1058, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int32_979 = torch.constant.int 32
+ %1059 = torch.aten.mul.Scalar %arg2, %int32_979 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1059, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int2_980 = torch.constant.int 2
+ %int1_981 = torch.constant.int 1
+ %1060 = torch.aten.add.Scalar %1059, %int2_980, %int1_981 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1060, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int2_982 = torch.constant.int 2
+ %1061 = torch.aten.mul.Scalar %1060, %int2_982 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1061, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int1_983 = torch.constant.int 1
+ %int1_984 = torch.constant.int 1
+ %1062 = torch.aten.add.Scalar %1061, %int1_983, %int1_984 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1062, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %1063 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list<int>
+ %1064 = torch.aten.view %1062, %1063 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
+ torch.bind_symbolic_shape %1064, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
%int4_985 = torch.constant.int 4
- %int14336_986 = torch.constant.int 14336
- %1074 = torch.prim.ListConstruct %int4_985, %306, %int14336_986 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1075 = torch.aten.view %1073, %1074 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %1075, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %1076 = torch.aten.mul.Tensor %1068, %1075 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %1076, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %int-2_987 = torch.constant.int -2
- %int-1_988 = torch.constant.int -1
- %1077 = torch.aten.transpose.int %36, %int-2_987, %int-1_988 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
- %int1_989 = torch.constant.int 1
- %1078 = torch.aten.size.int %1067, %int1_989 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int
- %int4_990 = torch.constant.int 4
- %1079 = torch.aten.mul.int %int4_990, %1078 : !torch.int, !torch.int -> !torch.int
- %int14336_991 = torch.constant.int 14336
- %1080 = torch.prim.ListConstruct %1079, %int14336_991 : (!torch.int, !torch.int) -> !torch.list<int>
- %1081 = torch.aten.view %1076, %1080 : !torch.vtensor<[4,?,14336],f16>, !torch.list<int> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %1081, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %1082 = torch.aten.mm %1081, %1077 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1082, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_992 = torch.constant.int 4
- %int4096_993 = torch.constant.int 4096
- %1083 = torch.prim.ListConstruct %int4_992, %1078, %int4096_993 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1084 = torch.aten.view %1082, %1083 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1084, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int1_994 = torch.constant.int 1
- %1085 = torch.aten.add.Tensor %1050, %1084, %int1_994 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1085, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int6_995 = torch.constant.int 6
- %1086 = torch.prims.convert_element_type %1085, %int6_995 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1086, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int2_996 = torch.constant.int 2
- %1087 = torch.aten.pow.Tensor_Scalar %1086, %int2_996 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1087, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int-1_997 = torch.constant.int -1
- %1088 = torch.prim.ListConstruct %int-1_997 : (!torch.int) -> !torch.list<int>
- %true_998 = torch.constant.bool true
- %none_999 = torch.constant.none
- %1089 = torch.aten.mean.dim %1087, %1088, %true_998, %none_999 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1089, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %float9.999990e-06_1000 = torch.constant.float 9.9999997473787516E-6
- %int1_1001 = torch.constant.int 1
- %1090 = torch.aten.add.Scalar %1089, %float9.999990e-06_1000, %int1_1001 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1090, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %1091 = torch.aten.rsqrt %1090 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1091, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %1092 = torch.aten.mul.Tensor %1086, %1091 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1092, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_1002 = torch.constant.int 5
- %1093 = torch.prims.convert_element_type %1092, %int5_1002 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1093, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %1094 = torch.aten.mul.Tensor %37, %1093 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1094, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_1003 = torch.constant.int 5
- %1095 = torch.prims.convert_element_type %1094, %int5_1003 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1095, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_1004 = torch.constant.int -2
- %int-1_1005 = torch.constant.int -1
- %1096 = torch.aten.transpose.int %38, %int-2_1004, %int-1_1005 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_1006 = torch.constant.int 4
- %1097 = torch.aten.mul.int %int4_1006, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_1007 = torch.constant.int 4096
- %1098 = torch.prim.ListConstruct %1097, %int4096_1007 : (!torch.int, !torch.int) -> !torch.list<int>
+ %int32_986 = torch.constant.int 32
+ %int8_987 = torch.constant.int 8
+ %int128_988 = torch.constant.int 128
+ %1065 = torch.prim.ListConstruct %int4_985, %296, %int32_986, %int8_987, %int128_988 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1066 = torch.aten.view %906, %1065 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %1066, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int32_989 = torch.constant.int 32
+ %int8_990 = torch.constant.int 8
+ %int128_991 = torch.constant.int 128
+ %1067 = torch.prim.ListConstruct %504, %int32_989, %int8_990, %int128_991 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1068 = torch.aten.view %1066, %1067 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
+ torch.bind_symbolic_shape %1068, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
+ %int1_992 = torch.constant.int 1
+ %int2_993 = torch.constant.int 2
+ %1069 = torch.aten.transpose.int %1068, %int1_992, %int2_993 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1069, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int5_994 = torch.constant.int 5
+ %1070 = torch.prims.convert_element_type %1069, %int5_994 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1070, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %1071 = torch.prim.ListConstruct %1064 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
+ %false_995 = torch.constant.bool false
+ %1072 = torch.aten.index_put %1058, %1071, %1070, %false_995 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1072, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int32_996 = torch.constant.int 32
+ %int2_997 = torch.constant.int 2
+ %int8_998 = torch.constant.int 8
+ %int32_999 = torch.constant.int 32
+ %int128_1000 = torch.constant.int 128
+ %1073 = torch.prim.ListConstruct %297, %int32_996, %int2_997, %int8_998, %int32_999, %int128_1000 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1074 = torch.aten.view %1072, %1073 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1074, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_1001 = torch.constant.int 2097152
+ %1075 = torch.prim.ListConstruct %297, %int2097152_1001 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1076 = torch.aten.view %1074, %1075 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %1076, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %int-2_1002 = torch.constant.int -2
+ %1077 = torch.aten.unsqueeze %1032, %int-2_1002 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+ torch.bind_symbolic_shape %1077, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+ %int4_1003 = torch.constant.int 4
+ %int8_1004 = torch.constant.int 8
+ %int4_1005 = torch.constant.int 4
+ %int128_1006 = torch.constant.int 128
+ %1078 = torch.prim.ListConstruct %int4_1003, %298, %int8_1004, %int4_1005, %int128_1006 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %false_1007 = torch.constant.bool false
+ %1079 = torch.aten.expand %1077, %1078, %false_1007 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %1079, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int0_1008 = torch.constant.int 0
+ %1080 = torch.aten.clone %1079, %int0_1008 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %1080, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4_1009 = torch.constant.int 4
+ %int32_1010 = torch.constant.int 32
+ %int128_1011 = torch.constant.int 128
+ %1081 = torch.prim.ListConstruct %int4_1009, %298, %int32_1010, %int128_1011 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1082 = torch.aten._unsafe_view %1080, %1081 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %1082, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int-2_1012 = torch.constant.int -2
+ %1083 = torch.aten.unsqueeze %906, %int-2_1012 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+ torch.bind_symbolic_shape %1083, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+ %int4_1013 = torch.constant.int 4
+ %int8_1014 = torch.constant.int 8
+ %int4_1015 = torch.constant.int 4
+ %int128_1016 = torch.constant.int 128
+ %1084 = torch.prim.ListConstruct %int4_1013, %298, %int8_1014, %int4_1015, %int128_1016 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %false_1017 = torch.constant.bool false
+ %1085 = torch.aten.expand %1083, %1084, %false_1017 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %1085, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int0_1018 = torch.constant.int 0
+ %1086 = torch.aten.clone %1085, %int0_1018 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %1086, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4_1019 = torch.constant.int 4
+ %int32_1020 = torch.constant.int 32
+ %int128_1021 = torch.constant.int 128
+ %1087 = torch.prim.ListConstruct %int4_1019, %298, %int32_1020, %int128_1021 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1088 = torch.aten._unsafe_view %1086, %1087 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %1088, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int1_1022 = torch.constant.int 1
+ %int2_1023 = torch.constant.int 2
+ %1089 = torch.aten.transpose.int %969, %int1_1022, %int2_1023 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %1089, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_1024 = torch.constant.int 1
+ %int2_1025 = torch.constant.int 2
+ %1090 = torch.aten.transpose.int %1082, %int1_1024, %int2_1025 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %1090, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_1026 = torch.constant.int 1
+ %int2_1027 = torch.constant.int 2
+ %1091 = torch.aten.transpose.int %1088, %int1_1026, %int2_1027 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %1091, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %float0.000000e00_1028 = torch.constant.float 0.000000e+00
+ %false_1029 = torch.constant.bool false
+ %none_1030 = torch.constant.none
+ %1092:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1089, %1090, %1091, %float0.000000e00_1028, %false_1029, %327, %none_1030) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>)
+ torch.bind_symbolic_shape %1092#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_1031 = torch.constant.int 1
+ %int2_1032 = torch.constant.int 2
+ %1093 = torch.aten.transpose.int %1092#0, %int1_1031, %int2_1032 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %1093, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int4_1033 = torch.constant.int 4
+ %int4096_1034 = torch.constant.int 4096
+ %1094 = torch.prim.ListConstruct %int4_1033, %298, %int4096_1034 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1095 = torch.aten.view %1093, %1094 : !torch.vtensor<[4,?,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %1095, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_1035 = torch.constant.int -2
+ %int-1_1036 = torch.constant.int -1
+ %1096 = torch.aten.transpose.int %24, %int-2_1035, %int-1_1036 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int5_1037 = torch.constant.int 5
+ %1097 = torch.prims.convert_element_type %1096, %int5_1037 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int4096_1038 = torch.constant.int 4096
+ %1098 = torch.prim.ListConstruct %342, %int4096_1038 : (!torch.int, !torch.int) -> !torch.list<int>
%1099 = torch.aten.view %1095, %1098 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1099, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %1100 = torch.aten.mm %1099, %1096 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1100, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_1008 = torch.constant.int 4
- %int4096_1009 = torch.constant.int 4096
- %1101 = torch.prim.ListConstruct %int4_1008, %306, %int4096_1009 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ torch.bind_symbolic_shape %1099, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %1100 = torch.aten.mm %1099, %1097 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %1100, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %int4_1039 = torch.constant.int 4
+
%int4096_1040 = torch.constant.int 4096 + %1101 = torch.prim.ListConstruct %int4_1039, %298, %int4096_1040 : (!torch.int, !torch.int, !torch.int) -> !torch.list %1102 = torch.aten.view %1100, %1101 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1102, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_1010 = torch.constant.int -2 - %int-1_1011 = torch.constant.int -1 - %1103 = torch.aten.transpose.int %39, %int-2_1010, %int-1_1011 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_1012 = torch.constant.int 4 - %1104 = torch.aten.mul.int %int4_1012, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1013 = torch.constant.int 4096 - %1105 = torch.prim.ListConstruct %1104, %int4096_1013 : (!torch.int, !torch.int) -> !torch.list - %1106 = torch.aten.view %1095, %1105 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1106, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1107 = torch.aten.mm %1106, %1103 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %1107, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_1014 = torch.constant.int 4 - %int1024_1015 = torch.constant.int 1024 - %1108 = torch.prim.ListConstruct %int4_1014, %306, %int1024_1015 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1109 = torch.aten.view %1107, %1108 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %1109, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_1016 = torch.constant.int -2 - %int-1_1017 = torch.constant.int -1 - %1110 = torch.aten.transpose.int %40, %int-2_1016, %int-1_1017 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_1018 = torch.constant.int 4 - %1111 = torch.aten.mul.int %int4_1018, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1019 = torch.constant.int 4096 - %1112 = torch.prim.ListConstruct %1111, %int4096_1019 : (!torch.int, !torch.int) -> !torch.list - %1113 = torch.aten.view %1095, %1112 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1113, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1114 = torch.aten.mm %1113, %1110 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %1114, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_1020 = torch.constant.int 4 - %int1024_1021 = torch.constant.int 1024 - %1115 = torch.prim.ListConstruct %int4_1020, %306, %int1024_1021 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1116 = torch.aten.view %1114, %1115 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %1116, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_1022 = torch.constant.int 4 - %int32_1023 = torch.constant.int 32 - %int128_1024 = torch.constant.int 128 - %1117 = torch.prim.ListConstruct %int4_1022, %306, %int32_1023, %int128_1024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1118 = torch.aten.view %1102, %1117 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> 
!torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1118, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_1025 = torch.constant.int 4 - %int8_1026 = torch.constant.int 8 - %int128_1027 = torch.constant.int 128 - %1119 = torch.prim.ListConstruct %int4_1025, %306, %int8_1026, %int128_1027 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1120 = torch.aten.view %1109, %1119 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1120, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_1028 = torch.constant.int 4 - %int8_1029 = torch.constant.int 8 - %int128_1030 = torch.constant.int 128 - %1121 = torch.prim.ListConstruct %int4_1028, %306, %int8_1029, %int128_1030 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1122 = torch.aten.view %1116, %1121 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1122, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_1031 = torch.constant.int 131072 - %none_1032 = torch.constant.none - %none_1033 = torch.constant.none - %cpu_1034 = torch.constant.device "cpu" - %false_1035 = torch.constant.bool false - %1123 = torch.aten.arange %int131072_1031, %none_1032, %none_1033, %cpu_1034, %false_1035 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_1036 = torch.constant.int 0 - %int128_1037 = torch.constant.int 128 - %none_1038 = torch.constant.none - %none_1039 = torch.constant.none - %cpu_1040 = torch.constant.device "cpu" - %false_1041 = torch.constant.bool false - %1124 = torch.aten.arange.start %int0_1036, %int128_1037, %none_1038, %none_1039, %cpu_1040, %false_1041 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_1042 = torch.constant.int 2 - %1125 = torch.aten.floor_divide.Scalar %1124, %int2_1042 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_1043 = torch.constant.int 6 - %1126 = torch.prims.convert_element_type %1125, %int6_1043 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_1044 = torch.constant.int 128 - %1127 = torch.aten.div.Scalar %1126, %int128_1044 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_1045 = torch.constant.float 2.000000e+00 - %1128 = torch.aten.mul.Scalar %1127, %float2.000000e00_1045 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_1046 = torch.constant.float 5.000000e+05 - %1129 = torch.aten.pow.Scalar %float5.000000e05_1046, %1128 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %1130 = torch.aten.reciprocal %1129 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_1047 = torch.constant.float 1.000000e+00 - %1131 = torch.aten.mul.Scalar %1130, %float1.000000e00_1047 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + torch.bind_symbolic_shape %1102, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_1041 = torch.constant.int 1 + %1103 = torch.aten.add.Tensor %869, %1102, %int1_1041 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1103, [%294], affine_map<()[s0] -> (4, s0 * 32, 
4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_1042 = torch.constant.int 6 + %1104 = torch.prims.convert_element_type %1103, %int6_1042 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1104, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_1043 = torch.constant.int 2 + %1105 = torch.aten.pow.Tensor_Scalar %1104, %int2_1043 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1105, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_1044 = torch.constant.int -1 + %1106 = torch.prim.ListConstruct %int-1_1044 : (!torch.int) -> !torch.list + %true_1045 = torch.constant.bool true + %none_1046 = torch.constant.none + %1107 = torch.aten.mean.dim %1105, %1106, %true_1045, %none_1046 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1107, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_1047 = torch.constant.float 9.9999997473787516E-6 %int1_1048 = torch.constant.int 1 - %1132 = torch.aten.unsqueeze %1123, %int1_1048 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_1049 = torch.constant.int 0 - %1133 = torch.aten.unsqueeze %1131, %int0_1049 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %1134 = torch.aten.mul.Tensor %1132, %1133 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_1050 = torch.constant.int 1 - %1135 = torch.aten.size.int %1102, %int1_1050 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_1051 = torch.constant.int 0 - %1136 = torch.aten.add.int %int0_1051, %1135 : !torch.int, !torch.int -> !torch.int - %int0_1052 = torch.constant.int 0 - %int0_1053 = torch.constant.int 0 - %int1_1054 = torch.constant.int 1 - %1137 = torch.aten.slice.Tensor %1134, %int0_1052, %int0_1053, %1136, %int1_1054 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1137, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1055 = torch.constant.int 1 - %int0_1056 = torch.constant.int 0 - %int9223372036854775807_1057 = torch.constant.int 9223372036854775807 - %int1_1058 = torch.constant.int 1 - %1138 = torch.aten.slice.Tensor %1137, %int1_1055, %int0_1056, %int9223372036854775807_1057, %int1_1058 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1138, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1059 = torch.constant.int 1 - %int0_1060 = torch.constant.int 0 - %int9223372036854775807_1061 = torch.constant.int 9223372036854775807 - %int1_1062 = torch.constant.int 1 - %1139 = torch.aten.slice.Tensor %1138, %int1_1059, %int0_1060, %int9223372036854775807_1061, %int1_1062 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1139, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_1063 = torch.constant.int 0 - %1140 = torch.aten.unsqueeze %1139, %int0_1063 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1140, [%292], affine_map<()[s0] -> (1, s0 * 32, 
128)> : !torch.vtensor<[1,?,128],f32> - %int1_1064 = torch.constant.int 1 - %int0_1065 = torch.constant.int 0 - %int9223372036854775807_1066 = torch.constant.int 9223372036854775807 - %int1_1067 = torch.constant.int 1 - %1141 = torch.aten.slice.Tensor %1140, %int1_1064, %int0_1065, %int9223372036854775807_1066, %int1_1067 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1141, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_1068 = torch.constant.int 2 - %int0_1069 = torch.constant.int 0 - %int9223372036854775807_1070 = torch.constant.int 9223372036854775807 - %int1_1071 = torch.constant.int 1 - %1142 = torch.aten.slice.Tensor %1141, %int2_1068, %int0_1069, %int9223372036854775807_1070, %int1_1071 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1142, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_1072 = torch.constant.int 4 - %int1_1073 = torch.constant.int 1 - %int1_1074 = torch.constant.int 1 - %1143 = torch.prim.ListConstruct %int4_1072, %int1_1073, %int1_1074 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1144 = torch.aten.repeat %1142, %1143 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %1144, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_1075 = torch.constant.int 6 - %1145 = torch.prims.convert_element_type %1118, %int6_1075 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %1145, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %1146 = torch_c.to_builtin_tensor %1145 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %1147 = torch_c.to_builtin_tensor %1144 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %1148 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%1146, %1147) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %1149 = torch_c.from_builtin_tensor %1148 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %1149, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_1076 = torch.constant.int 5 - %1150 = torch.prims.convert_element_type %1149, %int5_1076 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1150, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_1077 = torch.constant.int 131072 - %none_1078 = torch.constant.none - %none_1079 = torch.constant.none - %cpu_1080 = torch.constant.device "cpu" - %false_1081 = torch.constant.bool false - %1151 = torch.aten.arange %int131072_1077, %none_1078, %none_1079, %cpu_1080, %false_1081 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_1082 = torch.constant.int 0 - %int128_1083 = torch.constant.int 128 - %none_1084 = torch.constant.none - %none_1085 = torch.constant.none - %cpu_1086 = torch.constant.device "cpu" - %false_1087 = torch.constant.bool false - %1152 = torch.aten.arange.start %int0_1082, %int128_1083, %none_1084, %none_1085, %cpu_1086, %false_1087 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - 
%int2_1088 = torch.constant.int 2 - %1153 = torch.aten.floor_divide.Scalar %1152, %int2_1088 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_1089 = torch.constant.int 6 - %1154 = torch.prims.convert_element_type %1153, %int6_1089 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_1090 = torch.constant.int 128 - %1155 = torch.aten.div.Scalar %1154, %int128_1090 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_1091 = torch.constant.float 2.000000e+00 - %1156 = torch.aten.mul.Scalar %1155, %float2.000000e00_1091 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_1092 = torch.constant.float 5.000000e+05 - %1157 = torch.aten.pow.Scalar %float5.000000e05_1092, %1156 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %1158 = torch.aten.reciprocal %1157 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_1093 = torch.constant.float 1.000000e+00 - %1159 = torch.aten.mul.Scalar %1158, %float1.000000e00_1093 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_1094 = torch.constant.int 1 - %1160 = torch.aten.unsqueeze %1151, %int1_1094 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_1095 = torch.constant.int 0 - %1161 = torch.aten.unsqueeze %1159, %int0_1095 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %1162 = torch.aten.mul.Tensor %1160, %1161 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_1096 = torch.constant.int 1 - %1163 = torch.aten.size.int %1109, %int1_1096 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_1097 = torch.constant.int 0 - %1164 = torch.aten.add.int %int0_1097, %1163 : !torch.int, !torch.int -> !torch.int - %int0_1098 = torch.constant.int 0 - %int0_1099 = torch.constant.int 0 - %int1_1100 = torch.constant.int 1 - %1165 = torch.aten.slice.Tensor %1162, %int0_1098, %int0_1099, %1164, %int1_1100 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1165, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1101 = torch.constant.int 1 - %int0_1102 = torch.constant.int 0 - %int9223372036854775807_1103 = torch.constant.int 9223372036854775807 - %int1_1104 = torch.constant.int 1 - %1166 = torch.aten.slice.Tensor %1165, %int1_1101, %int0_1102, %int9223372036854775807_1103, %int1_1104 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1166, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1105 = torch.constant.int 1 - %int0_1106 = torch.constant.int 0 - %int9223372036854775807_1107 = torch.constant.int 9223372036854775807 - %int1_1108 = torch.constant.int 1 - %1167 = torch.aten.slice.Tensor %1166, %int1_1105, %int0_1106, %int9223372036854775807_1107, %int1_1108 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1167, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_1109 = torch.constant.int 0 - %1168 = torch.aten.unsqueeze %1167, %int0_1109 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1168, [%292], affine_map<()[s0] -> (1, s0 
* 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_1110 = torch.constant.int 1 + %1108 = torch.aten.add.Scalar %1107, %float9.999990e-06_1047, %int1_1048 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1108, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %1109 = torch.aten.rsqrt %1108 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1109, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %1110 = torch.aten.mul.Tensor %1104, %1109 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1110, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_1049 = torch.constant.int 5 + %1111 = torch.prims.convert_element_type %1110, %int5_1049 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1111, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %1112 = torch.aten.mul.Tensor %25, %1111 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1112, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_1050 = torch.constant.int 5 + %1113 = torch.prims.convert_element_type %1112, %int5_1050 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1113, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_1051 = torch.constant.int -2 + %int-1_1052 = torch.constant.int -1 + %1114 = torch.aten.transpose.int %26, %int-2_1051, %int-1_1052 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_1053 = torch.constant.int 5 + %1115 = torch.prims.convert_element_type %1114, %int5_1053 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_1054 = torch.constant.int 4096 + %1116 = torch.prim.ListConstruct %342, %int4096_1054 : (!torch.int, !torch.int) -> !torch.list + %1117 = torch.aten.view %1113, %1116 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1117, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1118 = torch.aten.mm %1117, %1115 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %1118, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_1055 = torch.constant.int 4 + %int14336_1056 = torch.constant.int 14336 + %1119 = torch.prim.ListConstruct %int4_1055, %298, %int14336_1056 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1120 = torch.aten.view %1118, %1119 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %1120, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %1121 = torch.aten.silu %1120 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %1121, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_1057 = torch.constant.int -2 + %int-1_1058 = torch.constant.int -1 + %1122 = torch.aten.transpose.int %27, %int-2_1057, %int-1_1058 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> 
!torch.vtensor<[4096,14336],f16> + %int5_1059 = torch.constant.int 5 + %1123 = torch.prims.convert_element_type %1122, %int5_1059 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_1060 = torch.constant.int 4096 + %1124 = torch.prim.ListConstruct %342, %int4096_1060 : (!torch.int, !torch.int) -> !torch.list + %1125 = torch.aten.view %1113, %1124 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1125, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1126 = torch.aten.mm %1125, %1123 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %1126, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_1061 = torch.constant.int 4 + %int14336_1062 = torch.constant.int 14336 + %1127 = torch.prim.ListConstruct %int4_1061, %298, %int14336_1062 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1128 = torch.aten.view %1126, %1127 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %1128, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %1129 = torch.aten.mul.Tensor %1121, %1128 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %1129, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_1063 = torch.constant.int -2 + %int-1_1064 = torch.constant.int -1 + %1130 = torch.aten.transpose.int %28, %int-2_1063, %int-1_1064 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_1065 = torch.constant.int 5 + %1131 = torch.prims.convert_element_type %1130, %int5_1065 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_1066 = torch.constant.int 14336 + %1132 = torch.prim.ListConstruct %342, %int14336_1066 : (!torch.int, !torch.int) -> !torch.list + %1133 = torch.aten.view %1129, %1132 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %1133, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %1134 = torch.aten.mm %1133, %1131 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1134, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_1067 = torch.constant.int 4 + %int4096_1068 = torch.constant.int 4096 + %1135 = torch.prim.ListConstruct %int4_1067, %298, %int4096_1068 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1136 = torch.aten.view %1134, %1135 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1136, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_1069 = torch.constant.int 1 + %1137 = torch.aten.add.Tensor %1103, %1136, %int1_1069 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1137, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_1070 = torch.constant.int 6 + %1138 = torch.prims.convert_element_type %1137, %int6_1070 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1138, [%294], 
affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_1071 = torch.constant.int 2 + %1139 = torch.aten.pow.Tensor_Scalar %1138, %int2_1071 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1139, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_1072 = torch.constant.int -1 + %1140 = torch.prim.ListConstruct %int-1_1072 : (!torch.int) -> !torch.list + %true_1073 = torch.constant.bool true + %none_1074 = torch.constant.none + %1141 = torch.aten.mean.dim %1139, %1140, %true_1073, %none_1074 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1141, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_1075 = torch.constant.float 9.9999997473787516E-6 + %int1_1076 = torch.constant.int 1 + %1142 = torch.aten.add.Scalar %1141, %float9.999990e-06_1075, %int1_1076 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1142, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %1143 = torch.aten.rsqrt %1142 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1143, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %1144 = torch.aten.mul.Tensor %1138, %1143 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1144, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_1077 = torch.constant.int 5 + %1145 = torch.prims.convert_element_type %1144, %int5_1077 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1145, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %1146 = torch.aten.mul.Tensor %29, %1145 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1146, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_1078 = torch.constant.int 5 + %1147 = torch.prims.convert_element_type %1146, %int5_1078 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1147, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_1079 = torch.constant.int -2 + %int-1_1080 = torch.constant.int -1 + %1148 = torch.aten.transpose.int %30, %int-2_1079, %int-1_1080 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_1081 = torch.constant.int 5 + %1149 = torch.prims.convert_element_type %1148, %int5_1081 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_1082 = torch.constant.int 4096 + %1150 = torch.prim.ListConstruct %342, %int4096_1082 : (!torch.int, !torch.int) -> !torch.list + %1151 = torch.aten.view %1147, %1150 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1151, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1152 = torch.aten.mm %1151, %1149 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1152, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + 
%int4_1083 = torch.constant.int 4 + %int4096_1084 = torch.constant.int 4096 + %1153 = torch.prim.ListConstruct %int4_1083, %298, %int4096_1084 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1154 = torch.aten.view %1152, %1153 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1154, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_1085 = torch.constant.int -2 + %int-1_1086 = torch.constant.int -1 + %1155 = torch.aten.transpose.int %31, %int-2_1085, %int-1_1086 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_1087 = torch.constant.int 5 + %1156 = torch.prims.convert_element_type %1155, %int5_1087 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_1088 = torch.constant.int 4096 + %1157 = torch.prim.ListConstruct %342, %int4096_1088 : (!torch.int, !torch.int) -> !torch.list + %1158 = torch.aten.view %1147, %1157 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1158, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1159 = torch.aten.mm %1158, %1156 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %1159, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_1089 = torch.constant.int 4 + %int1024_1090 = torch.constant.int 1024 + %1160 = torch.prim.ListConstruct %int4_1089, %298, %int1024_1090 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1161 = torch.aten.view %1159, %1160 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %1161, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_1091 = torch.constant.int -2 + %int-1_1092 = torch.constant.int -1 + %1162 = torch.aten.transpose.int %32, %int-2_1091, %int-1_1092 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_1093 = torch.constant.int 5 + %1163 = torch.prims.convert_element_type %1162, %int5_1093 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_1094 = torch.constant.int 4096 + %1164 = torch.prim.ListConstruct %342, %int4096_1094 : (!torch.int, !torch.int) -> !torch.list + %1165 = torch.aten.view %1147, %1164 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1165, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1166 = torch.aten.mm %1165, %1163 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %1166, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_1095 = torch.constant.int 4 + %int1024_1096 = torch.constant.int 1024 + %1167 = torch.prim.ListConstruct %int4_1095, %298, %int1024_1096 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1168 = torch.aten.view %1166, %1167 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %1168, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_1097 = torch.constant.int 4 + %int32_1098 = torch.constant.int 32 + %int128_1099 = torch.constant.int 128 + %1169 = torch.prim.ListConstruct %int4_1097, %298, %int32_1098, %int128_1099 : 
(!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1170 = torch.aten.view %1154, %1169 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1170, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_1100 = torch.constant.int 4 + %int8_1101 = torch.constant.int 8 + %int128_1102 = torch.constant.int 128 + %1171 = torch.prim.ListConstruct %int4_1100, %298, %int8_1101, %int128_1102 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1172 = torch.aten.view %1161, %1171 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1172, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_1103 = torch.constant.int 4 + %int8_1104 = torch.constant.int 8 + %int128_1105 = torch.constant.int 128 + %1173 = torch.prim.ListConstruct %int4_1103, %298, %int8_1104, %int128_1105 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1174 = torch.aten.view %1168, %1173 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1174, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_1106 = torch.constant.int 131072 + %none_1107 = torch.constant.none + %none_1108 = torch.constant.none + %cpu_1109 = torch.constant.device "cpu" + %false_1110 = torch.constant.bool false + %1175 = torch.aten.arange %int131072_1106, %none_1107, %none_1108, %cpu_1109, %false_1110 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> %int0_1111 = torch.constant.int 0 - %int9223372036854775807_1112 = torch.constant.int 9223372036854775807 - %int1_1113 = torch.constant.int 1 - %1169 = torch.aten.slice.Tensor %1168, %int1_1110, %int0_1111, %int9223372036854775807_1112, %int1_1113 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1169, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_1114 = torch.constant.int 2 - %int0_1115 = torch.constant.int 0 - %int9223372036854775807_1116 = torch.constant.int 9223372036854775807 - %int1_1117 = torch.constant.int 1 - %1170 = torch.aten.slice.Tensor %1169, %int2_1114, %int0_1115, %int9223372036854775807_1116, %int1_1117 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1170, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_1118 = torch.constant.int 4 - %int1_1119 = torch.constant.int 1 - %int1_1120 = torch.constant.int 1 - %1171 = torch.prim.ListConstruct %int4_1118, %int1_1119, %int1_1120 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1172 = torch.aten.repeat %1170, %1171 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %1172, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_1121 = torch.constant.int 6 - %1173 = torch.prims.convert_element_type %1120, %int6_1121 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %1173, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %1174 = torch_c.to_builtin_tensor %1173 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %1175 
= torch_c.to_builtin_tensor %1172 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %1176 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%1174, %1175) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %1177 = torch_c.from_builtin_tensor %1176 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %1177, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_1122 = torch.constant.int 5 - %1178 = torch.prims.convert_element_type %1177, %int5_1122 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1178, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_1123 = torch.constant.int 64 - %1179 = torch.aten.mul.Scalar %arg2, %int64_1123 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1179, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int128_1112 = torch.constant.int 128 + %int2_1113 = torch.constant.int 2 + %int4_1114 = torch.constant.int 4 + %none_1115 = torch.constant.none + %cpu_1116 = torch.constant.device "cpu" + %false_1117 = torch.constant.bool false + %1176 = torch.aten.arange.start_step %int0_1111, %int128_1112, %int2_1113, %int4_1114, %none_1115, %cpu_1116, %false_1117 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_1118 = torch.constant.int 6 + %1177 = torch.prims.convert_element_type %1176, %int6_1118 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_1119 = torch.constant.int 128 + %1178 = torch.aten.div.Scalar %1177, %int128_1119 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_1120 = torch.constant.float 5.000000e+05 + %1179 = torch.aten.pow.Scalar %float5.000000e05_1120, %1178 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1180 = torch.aten.reciprocal %1179 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_1121 = torch.constant.float 1.000000e+00 + %1181 = torch.aten.mul.Scalar %1180, %float1.000000e00_1121 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %1182 = torch.aten.reciprocal %1181 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_1122 = torch.constant.float 6.2831853071795862 + %1183 = torch.aten.mul.Scalar %1182, %float6.283190e00_1122 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_1123 = torch.constant.float 8.192000e+03 + %1184 = torch.aten.gt.Scalar %1183, %float8.192000e03_1123 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> %int8_1124 = torch.constant.int 8 - %int1_1125 = torch.constant.int 1 - %1180 = torch.aten.add.Scalar %1179, %int8_1124, %int1_1125 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1180, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_1126 = torch.constant.int 4 - %int32_1127 = torch.constant.int 32 - %int8_1128 = torch.constant.int 8 - %int128_1129 = torch.constant.int 128 - %1181 = torch.prim.ListConstruct %int4_1126, %398, %int32_1127, %int8_1128, %int128_1129 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1182 = torch.aten.view %1178, %1181 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - 
torch.bind_symbolic_shape %1182, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_1130 = torch.constant.int 4 - %1183 = torch.aten.mul.int %int4_1130, %398 : !torch.int, !torch.int -> !torch.int - %int32_1131 = torch.constant.int 32 - %int8_1132 = torch.constant.int 8 - %int128_1133 = torch.constant.int 128 - %1184 = torch.prim.ListConstruct %1183, %int32_1131, %int8_1132, %int128_1133 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1185 = torch.aten.view %1182, %1184 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1185, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_1134 = torch.constant.int 4 - %1186 = torch.aten.mul.int %int4_1134, %398 : !torch.int, !torch.int -> !torch.int - %1187 = torch.prim.ListConstruct %1186 : (!torch.int) -> !torch.list - %1188 = torch.aten.view %1180, %1187 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %1188, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_1135 = torch.constant.int 32 - %int2_1136 = torch.constant.int 2 - %int32_1137 = torch.constant.int 32 - %int8_1138 = torch.constant.int 8 - %int128_1139 = torch.constant.int 128 - %1189 = torch.prim.ListConstruct %389, %int32_1135, %int2_1136, %int32_1137, %int8_1138, %int128_1139 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1190 = torch.aten.view %1022, %1189 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1190, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_1140 = torch.constant.int 32 - %1191 = torch.aten.mul.int %389, %int32_1140 : !torch.int, !torch.int -> !torch.int - %int2_1141 = torch.constant.int 2 - %1192 = torch.aten.mul.int %1191, %int2_1141 : !torch.int, !torch.int -> !torch.int - %int32_1142 = torch.constant.int 32 - %int8_1143 = torch.constant.int 8 - %int128_1144 = torch.constant.int 128 - %1193 = torch.prim.ListConstruct %1192, %int32_1142, %int8_1143, %int128_1144 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1194 = torch.aten.view %1190, %1193 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1194, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %1195 = torch.prim.ListConstruct %1188 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_1145 = torch.constant.bool false - %1196 = torch.aten.index_put %1194, %1195, %1185, %false_1145 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1196, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_1146 = torch.constant.int 32 - %int2_1147 = torch.constant.int 2 - %int32_1148 = torch.constant.int 32 - %int8_1149 = torch.constant.int 8 - %int128_1150 = torch.constant.int 128 - %1197 = torch.prim.ListConstruct %389, %int32_1146, %int2_1147, %int32_1148, %int8_1149, %int128_1150 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1198 = torch.aten.view %1196, %1197 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1198, [%293], 
affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_1151 = torch.constant.int 2097152 - %1199 = torch.prim.ListConstruct %389, %int2097152_1151 : (!torch.int, !torch.int) -> !torch.list - %1200 = torch.aten.view %1198, %1199 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1200, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_1152 = torch.constant.int 32 - %int2_1153 = torch.constant.int 2 - %int32_1154 = torch.constant.int 32 - %int8_1155 = torch.constant.int 8 - %int128_1156 = torch.constant.int 128 - %1201 = torch.prim.ListConstruct %389, %int32_1152, %int2_1153, %int32_1154, %int8_1155, %int128_1156 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1202 = torch.aten.view %1200, %1201 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1202, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_1157 = torch.constant.int 32 - %int8_1158 = torch.constant.int 8 - %int128_1159 = torch.constant.int 128 - %1203 = torch.prim.ListConstruct %1192, %int32_1157, %int8_1158, %int128_1159 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1204 = torch.aten.view %1202, %1203 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1204, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_1160 = torch.constant.int 4 - %int32_1161 = torch.constant.int 32 - %int8_1162 = torch.constant.int 8 - %int128_1163 = torch.constant.int 128 - %1205 = torch.prim.ListConstruct %int4_1160, %398, %int32_1161, %int8_1162, %int128_1163 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1206 = torch.aten.view %1122, %1205 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1206, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_1164 = torch.constant.int 4 - %1207 = torch.aten.mul.int %int4_1164, %398 : !torch.int, !torch.int -> !torch.int - %int32_1165 = torch.constant.int 32 - %int8_1166 = torch.constant.int 8 - %int128_1167 = torch.constant.int 128 - %1208 = torch.prim.ListConstruct %1207, %int32_1165, %int8_1166, %int128_1167 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1209 = torch.aten.view %1206, %1208 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1209, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %1185 = torch.aten.div.Scalar %1181, %int8_1124 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1186 = torch.aten.where.self %1184, %1185, %1181 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1187 = torch.aten.reciprocal %1183 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_1125 = torch.constant.int 8192 + %1188 = torch.aten.mul.Scalar %1187, %int8192_1125 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_1126 = torch.constant.int 1 + %int1_1127 = torch.constant.int 1 + %1189 = torch.aten.sub.Scalar %1188, %int1_1126, %int1_1127 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> 
!torch.vtensor<[64],f32> + %int3_1128 = torch.constant.int 3 + %1190 = torch.aten.div.Scalar %1189, %int3_1128 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_1129 = torch.constant.int 1 + %int1_1130 = torch.constant.int 1 + %1191 = torch.aten.rsub.Scalar %1190, %int1_1129, %int1_1130 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %1192 = torch.aten.mul.Tensor %1191, %1186 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_1131 = torch.constant.int 8 + %1193 = torch.aten.div.Scalar %1192, %int8_1131 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1194 = torch.aten.mul.Tensor %1190, %1186 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_1132 = torch.constant.int 1 + %1195 = torch.aten.add.Tensor %1193, %1194, %int1_1132 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_1133 = torch.constant.float 2.048000e+03 + %1196 = torch.aten.lt.Scalar %1183, %float2.048000e03_1133 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %1197 = torch.aten.bitwise_not %1196 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_1134 = torch.constant.float 8.192000e+03 + %1198 = torch.aten.gt.Scalar %1183, %float8.192000e03_1134 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %1199 = torch.aten.bitwise_not %1198 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %1200 = torch.aten.mul.Tensor %1197, %1199 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %1201 = torch.aten.where.self %1200, %1195, %1186 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1202 = torch.prim.ListConstruct %1201, %1201 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_1135 = torch.constant.int -1 + %1203 = torch.aten.cat %1202, %int-1_1135 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_1136 = torch.constant.int 6 + %1204 = torch.prims.convert_element_type %1203, %int6_1136 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_1137 = torch.constant.int 1 + %1205 = torch.aten.unsqueeze %1175, %int1_1137 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_1138 = torch.constant.int 6 + %1206 = torch.prims.convert_element_type %1205, %int6_1138 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_1139 = torch.constant.int 0 + %1207 = torch.aten.unsqueeze %1204, %int0_1139 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_1140 = torch.constant.int 6 + %1208 = torch.prims.convert_element_type %1207, %int6_1140 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %1209 = torch.aten.mul.Tensor %1206, %1208 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %1210 = torch.aten.cos %1209 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_1141 = torch.constant.int 5 + %1211 = torch.prims.convert_element_type %1210, %int5_1141 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %1212 = torch.aten.sin %1209 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_1142 = torch.constant.int 5 + %1213 = torch.prims.convert_element_type %1212, %int5_1142 : 
!torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_1143 = torch.constant.int 0 + %int0_1144 = torch.constant.int 0 + %int1_1145 = torch.constant.int 1 + %1214 = torch.aten.slice.Tensor %1211, %int0_1143, %int0_1144, %298, %int1_1145 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1214, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_1146 = torch.constant.int 1 + %int0_1147 = torch.constant.int 0 + %int9223372036854775807_1148 = torch.constant.int 9223372036854775807 + %int1_1149 = torch.constant.int 1 + %1215 = torch.aten.slice.Tensor %1214, %int1_1146, %int0_1147, %int9223372036854775807_1148, %int1_1149 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1215, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_1150 = torch.constant.int 0 + %int0_1151 = torch.constant.int 0 + %int1_1152 = torch.constant.int 1 + %1216 = torch.aten.slice.Tensor %1213, %int0_1150, %int0_1151, %298, %int1_1152 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1216, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_1153 = torch.constant.int 1 + %int0_1154 = torch.constant.int 0 + %int9223372036854775807_1155 = torch.constant.int 9223372036854775807 + %int1_1156 = torch.constant.int 1 + %1217 = torch.aten.slice.Tensor %1216, %int1_1153, %int0_1154, %int9223372036854775807_1155, %int1_1156 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1217, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_1157 = torch.constant.int 0 + %1218 = torch.aten.unsqueeze %1215, %int0_1157 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1218, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_1158 = torch.constant.int 1 + %int0_1159 = torch.constant.int 0 + %int9223372036854775807_1160 = torch.constant.int 9223372036854775807 + %int1_1161 = torch.constant.int 1 + %1219 = torch.aten.slice.Tensor %1218, %int1_1158, %int0_1159, %int9223372036854775807_1160, %int1_1161 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1219, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_1162 = torch.constant.int 2 + %1220 = torch.aten.unsqueeze %1219, %int2_1162 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1220, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_1163 = torch.constant.int 3 + %int0_1164 = torch.constant.int 0 + %int9223372036854775807_1165 = torch.constant.int 9223372036854775807 + %int1_1166 = torch.constant.int 1 + %1221 = torch.aten.slice.Tensor %1220, %int3_1163, %int0_1164, %int9223372036854775807_1165, %int1_1166 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1221, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_1167 = torch.constant.int 4 
%int1_1168 = torch.constant.int 1 %int1_1169 = torch.constant.int 1 - %1210 = torch.aten.add.Scalar %1180, %int1_1168, %int1_1169 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1210, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_1170 = torch.constant.int 4 - %1211 = torch.aten.mul.int %int4_1170, %398 : !torch.int, !torch.int -> !torch.int - %1212 = torch.prim.ListConstruct %1211 : (!torch.int) -> !torch.list - %1213 = torch.aten.view %1210, %1212 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %1213, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %1214 = torch.prim.ListConstruct %1213 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_1171 = torch.constant.bool false - %1215 = torch.aten.index_put %1204, %1214, %1209, %false_1171 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1215, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_1172 = torch.constant.int 32 - %int2_1173 = torch.constant.int 2 - %int32_1174 = torch.constant.int 32 - %int8_1175 = torch.constant.int 8 - %int128_1176 = torch.constant.int 128 - %1216 = torch.prim.ListConstruct %389, %int32_1172, %int2_1173, %int32_1174, %int8_1175, %int128_1176 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1217 = torch.aten.view %1215, %1216 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1217, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_1177 = torch.constant.int 2097152 - %1218 = torch.prim.ListConstruct %389, %int2097152_1177 : (!torch.int, !torch.int) -> !torch.list - %1219 = torch.aten.view %1217, %1218 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1219, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_1178 = torch.constant.int -2 - %1220 = torch.aten.unsqueeze %1178, %int-2_1178 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1220, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_1179 = torch.constant.int 4 - %int8_1180 = torch.constant.int 8 + %int1_1170 = torch.constant.int 1 + %1222 = torch.prim.ListConstruct %int4_1167, %int1_1168, %int1_1169, %int1_1170 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1223 = torch.aten.repeat %1221, %1222 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %1223, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_1171 = torch.constant.int 0 + %1224 = torch.aten.unsqueeze %1217, %int0_1171 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1224, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_1172 = torch.constant.int 1 + %int0_1173 = torch.constant.int 0 + %int9223372036854775807_1174 = torch.constant.int 9223372036854775807 + %int1_1175 = torch.constant.int 1 + %1225 = torch.aten.slice.Tensor %1224, %int1_1172, %int0_1173, %int9223372036854775807_1174, %int1_1175 : 
!torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1225, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_1176 = torch.constant.int 2 + %1226 = torch.aten.unsqueeze %1225, %int2_1176 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1226, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_1177 = torch.constant.int 3 + %int0_1178 = torch.constant.int 0 + %int9223372036854775807_1179 = torch.constant.int 9223372036854775807 + %int1_1180 = torch.constant.int 1 + %1227 = torch.aten.slice.Tensor %1226, %int3_1177, %int0_1178, %int9223372036854775807_1179, %int1_1180 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1227, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> %int4_1181 = torch.constant.int 4 - %int128_1182 = torch.constant.int 128 - %1221 = torch.prim.ListConstruct %int4_1179, %1163, %int8_1180, %int4_1181, %int128_1182 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1183 = torch.constant.bool false - %1222 = torch.aten.expand %1220, %1221, %false_1183 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1222, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1184 = torch.constant.int 0 - %1223 = torch.aten.clone %1222, %int0_1184 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1223, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_1185 = torch.constant.int 4 - %int32_1186 = torch.constant.int 32 - %int128_1187 = torch.constant.int 128 - %1224 = torch.prim.ListConstruct %int4_1185, %1163, %int32_1186, %int128_1187 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1225 = torch.aten._unsafe_view %1223, %1224 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1225, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_1188 = torch.constant.int -2 - %1226 = torch.aten.unsqueeze %1122, %int-2_1188 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1226, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_1189 = torch.constant.int 1 - %1227 = torch.aten.size.int %1116, %int1_1189 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_1190 = torch.constant.int 4 - %int8_1191 = torch.constant.int 8 - %int4_1192 = torch.constant.int 4 - %int128_1193 = torch.constant.int 128 - %1228 = torch.prim.ListConstruct %int4_1190, %1227, %int8_1191, %int4_1192, %int128_1193 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1194 = torch.constant.bool false - %1229 = torch.aten.expand %1226, %1228, %false_1194 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1229, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1195 = torch.constant.int 0 - %1230 = torch.aten.clone %1229, 
%int0_1195 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1230, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_1196 = torch.constant.int 4 - %int32_1197 = torch.constant.int 32 - %int128_1198 = torch.constant.int 128 - %1231 = torch.prim.ListConstruct %int4_1196, %1227, %int32_1197, %int128_1198 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1232 = torch.aten._unsafe_view %1230, %1231 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1232, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_1199 = torch.constant.int 1 - %int2_1200 = torch.constant.int 2 - %1233 = torch.aten.transpose.int %1150, %int1_1199, %int2_1200 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1233, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_1201 = torch.constant.int 1 + %int1_1182 = torch.constant.int 1 + %int1_1183 = torch.constant.int 1 + %int1_1184 = torch.constant.int 1 + %1228 = torch.prim.ListConstruct %int4_1181, %int1_1182, %int1_1183, %int1_1184 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1229 = torch.aten.repeat %1227, %1228 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %1229, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %1230 = torch.aten.mul.Tensor %1170, %1223 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1230, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_1185 = torch.constant.int 3 + %int0_1186 = torch.constant.int 0 + %int64_1187 = torch.constant.int 64 + %int1_1188 = torch.constant.int 1 + %1231 = torch.aten.slice.Tensor %1170, %int3_1185, %int0_1186, %int64_1187, %int1_1188 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %1231, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_1189 = torch.constant.int 3 + %int64_1190 = torch.constant.int 64 + %int9223372036854775807_1191 = torch.constant.int 9223372036854775807 + %int1_1192 = torch.constant.int 1 + %1232 = torch.aten.slice.Tensor %1170, %int3_1189, %int64_1190, %int9223372036854775807_1191, %int1_1192 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %1232, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %1233 = torch.aten.neg %1232 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %1233, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %1234 = torch.prim.ListConstruct %1233, %1231 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_1193 = torch.constant.int -1 + %1235 = torch.aten.cat %1234, %int-1_1193 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1235, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %1236 = torch.aten.mul.Tensor %1235, 
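[note] The `-` ops removed in this hunk are the old decomposed attention path: each of the 8 KV heads is unsqueezed, expanded 4x, and _unsafe_view'd into 32 heads to match the query count, then Q/K/V are transposed to [4, 32, s, 128] for a causal _scaled_dot_product_flash_attention_for_cpu call. A sketch of what that removed path computed, under the shapes shown in the IR:

import torch

def old_decomposed_attention(q, k, v):
    # q: [4, s, 32, 128]; k, v: [4, s, 8, 128] -- the '-' ops above
    b, s, kvh, d = k.shape
    # unsqueeze -> expand -> _unsafe_view: repeat each KV head 4x (GQA)
    k = k[:, :, :, None, :].expand(b, s, kvh, 4, d).reshape(b, s, kvh * 4, d)
    v = v[:, :, :, None, :].expand(b, s, kvh, 4, d).reshape(b, s, kvh * 4, d)
    # transpose to [b, heads, s, d] and run causal SDPA, as the deleted op did
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(1, 2).reshape(b, s, -1)   # back to [4, s, 4096]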
%1229 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1236, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_1194 = torch.constant.int 1 + %1237 = torch.aten.add.Tensor %1230, %1236, %int1_1194 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1237, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_1195 = torch.constant.int 131072 + %none_1196 = torch.constant.none + %none_1197 = torch.constant.none + %cpu_1198 = torch.constant.device "cpu" + %false_1199 = torch.constant.bool false + %1238 = torch.aten.arange %int131072_1195, %none_1196, %none_1197, %cpu_1198, %false_1199 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_1200 = torch.constant.int 0 + %int128_1201 = torch.constant.int 128 %int2_1202 = torch.constant.int 2 - %1234 = torch.aten.transpose.int %1225, %int1_1201, %int2_1202 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1234, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_1203 = torch.constant.int 1 - %int2_1204 = torch.constant.int 2 - %1235 = torch.aten.transpose.int %1232, %int1_1203, %int2_1204 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1235, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_1205 = torch.constant.float 0.000000e+00 - %true_1206 = torch.constant.bool true - %none_1207 = torch.constant.none - %none_1208 = torch.constant.none - %1236:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1233, %1234, %1235, %float0.000000e00_1205, %true_1206, %none_1207, %none_1208) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %1236#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_1209 = torch.constant.int 1 - %int2_1210 = torch.constant.int 2 - %1237 = torch.aten.transpose.int %1236#0, %int1_1209, %int2_1210 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1237, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_1211 = torch.constant.int 4 - %int4096_1212 = torch.constant.int 4096 - %1238 = torch.prim.ListConstruct %int4_1211, %1135, %int4096_1212 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1239 = torch.aten.view %1237, %1238 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1239, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_1213 = torch.constant.int -2 - %int-1_1214 = torch.constant.int -1 - %1240 = torch.aten.transpose.int %41, %int-2_1213, %int-1_1214 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_1215 = torch.constant.int 4 - %1241 = torch.aten.mul.int %int4_1215, %1135 : !torch.int, !torch.int -> !torch.int - %int4096_1216 
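[note] %1230 through %1237 above are the half-split rotary application for Q: multiply by cos, rotate the halves (negate the top 64 lanes and swap), multiply by sin, add. Equivalent PyTorch:

import torch

def rotate_half(x):
    # the slice 0:64 / 64:128, aten.neg, and aten.cat ops
    x1, x2 = x[..., :64], x[..., 64:]
    return torch.cat([-x2, x1], dim=-1)

def apply_rope(q, cos, sin):
    # q: [4, s, 32, 128]; cos/sin: [4, s, 1, 128], broadcast over the head dim
    return q * cos + rotate_half(q) * sin   # the two muls and the add.Tensor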
= torch.constant.int 4096 - %1242 = torch.prim.ListConstruct %1241, %int4096_1216 : (!torch.int, !torch.int) -> !torch.list - %1243 = torch.aten.view %1239, %1242 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1243, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1244 = torch.aten.mm %1243, %1240 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1244, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_1217 = torch.constant.int 4 - %int4096_1218 = torch.constant.int 4096 - %1245 = torch.prim.ListConstruct %int4_1217, %1135, %int4096_1218 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1246 = torch.aten.view %1244, %1245 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1246, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int4_1203 = torch.constant.int 4 + %none_1204 = torch.constant.none + %cpu_1205 = torch.constant.device "cpu" + %false_1206 = torch.constant.bool false + %1239 = torch.aten.arange.start_step %int0_1200, %int128_1201, %int2_1202, %int4_1203, %none_1204, %cpu_1205, %false_1206 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_1207 = torch.constant.int 6 + %1240 = torch.prims.convert_element_type %1239, %int6_1207 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_1208 = torch.constant.int 128 + %1241 = torch.aten.div.Scalar %1240, %int128_1208 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_1209 = torch.constant.float 5.000000e+05 + %1242 = torch.aten.pow.Scalar %float5.000000e05_1209, %1241 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1243 = torch.aten.reciprocal %1242 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_1210 = torch.constant.float 1.000000e+00 + %1244 = torch.aten.mul.Scalar %1243, %float1.000000e00_1210 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %1245 = torch.aten.reciprocal %1244 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_1211 = torch.constant.float 6.2831853071795862 + %1246 = torch.aten.mul.Scalar %1245, %float6.283190e00_1211 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_1212 = torch.constant.float 8.192000e+03 + %1247 = torch.aten.gt.Scalar %1246, %float8.192000e03_1212 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_1213 = torch.constant.int 8 + %1248 = torch.aten.div.Scalar %1244, %int8_1213 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1249 = torch.aten.where.self %1247, %1248, %1244 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1250 = torch.aten.reciprocal %1246 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_1214 = torch.constant.int 8192 + %1251 = torch.aten.mul.Scalar %1250, %int8192_1214 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_1215 = torch.constant.int 1 + %int1_1216 = torch.constant.int 1 + %1252 = torch.aten.sub.Scalar %1251, %int1_1215, %int1_1216 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_1217 = torch.constant.int 3 + %1253 = 
torch.aten.div.Scalar %1252, %int3_1217 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_1218 = torch.constant.int 1 %int1_1219 = torch.constant.int 1 - %1247 = torch.aten.add.Tensor %1085, %1246, %int1_1219 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1247, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_1220 = torch.constant.int 6 - %1248 = torch.prims.convert_element_type %1247, %int6_1220 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1248, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_1221 = torch.constant.int 2 - %1249 = torch.aten.pow.Tensor_Scalar %1248, %int2_1221 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1249, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_1222 = torch.constant.int -1 - %1250 = torch.prim.ListConstruct %int-1_1222 : (!torch.int) -> !torch.list - %true_1223 = torch.constant.bool true - %none_1224 = torch.constant.none - %1251 = torch.aten.mean.dim %1249, %1250, %true_1223, %none_1224 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1251, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_1225 = torch.constant.float 9.9999997473787516E-6 + %1254 = torch.aten.rsub.Scalar %1253, %int1_1218, %int1_1219 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %1255 = torch.aten.mul.Tensor %1254, %1249 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_1220 = torch.constant.int 8 + %1256 = torch.aten.div.Scalar %1255, %int8_1220 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1257 = torch.aten.mul.Tensor %1253, %1249 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_1221 = torch.constant.int 1 + %1258 = torch.aten.add.Tensor %1256, %1257, %int1_1221 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_1222 = torch.constant.float 2.048000e+03 + %1259 = torch.aten.lt.Scalar %1246, %float2.048000e03_1222 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %1260 = torch.aten.bitwise_not %1259 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_1223 = torch.constant.float 8.192000e+03 + %1261 = torch.aten.gt.Scalar %1246, %float8.192000e03_1223 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %1262 = torch.aten.bitwise_not %1261 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %1263 = torch.aten.mul.Tensor %1260, %1262 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %1264 = torch.aten.where.self %1263, %1258, %1249 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1265 = torch.prim.ListConstruct %1264, %1264 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_1224 = torch.constant.int -1 + %1266 = torch.aten.cat %1265, %int-1_1224 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_1225 = torch.constant.int 6 + %1267 = torch.prims.convert_element_type %1266, %int6_1225 : !torch.vtensor<[128],f32>, 
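[note] Worth flagging in this hunk: the `+` path builds 64 frequencies directly with arange.start_step(0, 128, 2) and duplicates them with a cat, while the deleted path (visible further down) took arange(128) and floor-divided by 2, which duplicates each frequency in interleaved even/odd order. The frequency multiset is identical; only the lane layout changes, consistent with the switch from the interleaved rotary kernel to the rotate-half form above. A quick check:

import torch

d = 128
# old '-' path: interleaved exponents [0, 0, 2, 2, ..., 126, 126] / 128
old_exp = torch.div(torch.arange(d), 2, rounding_mode='floor').float() / d * 2.0
# new '+' path: 64 exponents [0, 2, ..., 126] / 128, duplicated block-wise by the cat
new_half = torch.arange(0, d, 2).float() / d
new_exp = torch.cat([new_half, new_half])
# same multiset of frequencies, different lane order
assert torch.equal(old_exp.sort().values, new_exp.sort().values)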
!torch.int -> !torch.vtensor<[128],f32> %int1_1226 = torch.constant.int 1 - %1252 = torch.aten.add.Scalar %1251, %float9.999990e-06_1225, %int1_1226 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1252, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %1253 = torch.aten.rsqrt %1252 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1253, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %1254 = torch.aten.mul.Tensor %1248, %1253 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1254, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_1227 = torch.constant.int 5 - %1255 = torch.prims.convert_element_type %1254, %int5_1227 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1255, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %1256 = torch.aten.mul.Tensor %42, %1255 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1256, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_1228 = torch.constant.int 5 - %1257 = torch.prims.convert_element_type %1256, %int5_1228 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1257, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_1229 = torch.constant.int -2 - %int-1_1230 = torch.constant.int -1 - %1258 = torch.aten.transpose.int %43, %int-2_1229, %int-1_1230 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_1231 = torch.constant.int 4 - %1259 = torch.aten.mul.int %int4_1231, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1232 = torch.constant.int 4096 - %1260 = torch.prim.ListConstruct %1259, %int4096_1232 : (!torch.int, !torch.int) -> !torch.list - %1261 = torch.aten.view %1257, %1260 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1261, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1262 = torch.aten.mm %1261, %1258 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %1262, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_1233 = torch.constant.int 4 - %int14336_1234 = torch.constant.int 14336 - %1263 = torch.prim.ListConstruct %int4_1233, %306, %int14336_1234 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1264 = torch.aten.view %1262, %1263 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %1264, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %1265 = torch.aten.silu %1264 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %1265, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_1235 = torch.constant.int -2 - %int-1_1236 = torch.constant.int -1 - %1266 = torch.aten.transpose.int %44, %int-2_1235, %int-1_1236 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_1237 = 
torch.constant.int 4 - %1267 = torch.aten.mul.int %int4_1237, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1238 = torch.constant.int 4096 - %1268 = torch.prim.ListConstruct %1267, %int4096_1238 : (!torch.int, !torch.int) -> !torch.list - %1269 = torch.aten.view %1257, %1268 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1269, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1270 = torch.aten.mm %1269, %1266 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %1270, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_1239 = torch.constant.int 4 - %int14336_1240 = torch.constant.int 14336 - %1271 = torch.prim.ListConstruct %int4_1239, %306, %int14336_1240 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1272 = torch.aten.view %1270, %1271 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %1272, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %1273 = torch.aten.mul.Tensor %1265, %1272 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %1273, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_1241 = torch.constant.int -2 - %int-1_1242 = torch.constant.int -1 - %1274 = torch.aten.transpose.int %45, %int-2_1241, %int-1_1242 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_1243 = torch.constant.int 1 - %1275 = torch.aten.size.int %1264, %int1_1243 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_1244 = torch.constant.int 4 - %1276 = torch.aten.mul.int %int4_1244, %1275 : !torch.int, !torch.int -> !torch.int - %int14336_1245 = torch.constant.int 14336 - %1277 = torch.prim.ListConstruct %1276, %int14336_1245 : (!torch.int, !torch.int) -> !torch.list - %1278 = torch.aten.view %1273, %1277 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %1278, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %1279 = torch.aten.mm %1278, %1274 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1279, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_1246 = torch.constant.int 4 - %int4096_1247 = torch.constant.int 4096 - %1280 = torch.prim.ListConstruct %int4_1246, %1275, %int4096_1247 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1281 = torch.aten.view %1279, %1280 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1281, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_1248 = torch.constant.int 1 - %1282 = torch.aten.add.Tensor %1247, %1281, %int1_1248 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1282, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_1249 = torch.constant.int 6 - %1283 = torch.prims.convert_element_type %1282, %int6_1249 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1283, [%292], affine_map<()[s0] -> 
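[note] The `-` run here deletes the old layer's SwiGLU feed-forward, spelled out op by op: gate and up projections from [4, s, 4096] into 14336 features, silu on the gate, elementwise product, down projection, residual add. As a compact sketch, with weight shapes as in the globals:

import torch

def swiglu_ffn(x, w_gate, w_up, w_down):
    # x: [4, s, 4096]; w_gate, w_up: [14336, 4096]; w_down: [4096, 14336]
    flat = x.reshape(-1, 4096)                        # the view to [?, 4096]
    gate = torch.nn.functional.silu(flat @ w_gate.T)  # aten.mm + aten.silu
    up = flat @ w_up.T
    out = (gate * up) @ w_down.T                      # mul.Tensor, then down proj
    return x + out.reshape(x.shape)                   # the residual add.Tensor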
(4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_1250 = torch.constant.int 2 - %1284 = torch.aten.pow.Tensor_Scalar %1283, %int2_1250 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1284, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_1251 = torch.constant.int -1 - %1285 = torch.prim.ListConstruct %int-1_1251 : (!torch.int) -> !torch.list - %true_1252 = torch.constant.bool true - %none_1253 = torch.constant.none - %1286 = torch.aten.mean.dim %1284, %1285, %true_1252, %none_1253 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1286, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_1254 = torch.constant.float 9.9999997473787516E-6 + %1268 = torch.aten.unsqueeze %1238, %int1_1226 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_1227 = torch.constant.int 6 + %1269 = torch.prims.convert_element_type %1268, %int6_1227 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_1228 = torch.constant.int 0 + %1270 = torch.aten.unsqueeze %1267, %int0_1228 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_1229 = torch.constant.int 6 + %1271 = torch.prims.convert_element_type %1270, %int6_1229 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %1272 = torch.aten.mul.Tensor %1269, %1271 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %1273 = torch.aten.cos %1272 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_1230 = torch.constant.int 5 + %1274 = torch.prims.convert_element_type %1273, %int5_1230 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %1275 = torch.aten.sin %1272 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_1231 = torch.constant.int 5 + %1276 = torch.prims.convert_element_type %1275, %int5_1231 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_1232 = torch.constant.int 0 + %int0_1233 = torch.constant.int 0 + %int1_1234 = torch.constant.int 1 + %1277 = torch.aten.slice.Tensor %1274, %int0_1232, %int0_1233, %298, %int1_1234 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1277, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_1235 = torch.constant.int 1 + %int0_1236 = torch.constant.int 0 + %int9223372036854775807_1237 = torch.constant.int 9223372036854775807 + %int1_1238 = torch.constant.int 1 + %1278 = torch.aten.slice.Tensor %1277, %int1_1235, %int0_1236, %int9223372036854775807_1237, %int1_1238 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1278, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_1239 = torch.constant.int 0 + %int0_1240 = torch.constant.int 0 + %int1_1241 = torch.constant.int 1 + %1279 = torch.aten.slice.Tensor %1276, %int0_1239, %int0_1240, %298, %int1_1241 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1279, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> 
+ %int1_1242 = torch.constant.int 1 + %int0_1243 = torch.constant.int 0 + %int9223372036854775807_1244 = torch.constant.int 9223372036854775807 + %int1_1245 = torch.constant.int 1 + %1280 = torch.aten.slice.Tensor %1279, %int1_1242, %int0_1243, %int9223372036854775807_1244, %int1_1245 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1280, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_1246 = torch.constant.int 0 + %1281 = torch.aten.unsqueeze %1278, %int0_1246 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1281, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_1247 = torch.constant.int 1 + %int0_1248 = torch.constant.int 0 + %int9223372036854775807_1249 = torch.constant.int 9223372036854775807 + %int1_1250 = torch.constant.int 1 + %1282 = torch.aten.slice.Tensor %1281, %int1_1247, %int0_1248, %int9223372036854775807_1249, %int1_1250 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1282, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_1251 = torch.constant.int 2 + %1283 = torch.aten.unsqueeze %1282, %int2_1251 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1283, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_1252 = torch.constant.int 3 + %int0_1253 = torch.constant.int 0 + %int9223372036854775807_1254 = torch.constant.int 9223372036854775807 %int1_1255 = torch.constant.int 1 - %1287 = torch.aten.add.Scalar %1286, %float9.999990e-06_1254, %int1_1255 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1287, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %1288 = torch.aten.rsqrt %1287 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1288, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %1289 = torch.aten.mul.Tensor %1283, %1288 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1289, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_1256 = torch.constant.int 5 - %1290 = torch.prims.convert_element_type %1289, %int5_1256 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1290, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %1291 = torch.aten.mul.Tensor %46, %1290 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1291, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_1257 = torch.constant.int 5 - %1292 = torch.prims.convert_element_type %1291, %int5_1257 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1292, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_1258 = torch.constant.int -2 - %int-1_1259 = torch.constant.int -1 - %1293 = torch.aten.transpose.int %47, %int-2_1258, %int-1_1259 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> 
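[note] The deleted norm ops interleaved through this hunk are a stock RMSNorm: square, mean over the hidden dim, add eps (9.9999997473787516e-6), rsqrt, scale, all in f32, then cast to f16 before the f32 weight multiply promotes and a final cast brings it back to f16. Sketch:

import torch

def rms_norm(x_f16, weight_f32, eps=9.9999997473787516e-06):
    x = x_f16.float()                          # convert_element_type to f32
    var = x.pow(2).mean(dim=-1, keepdim=True)  # pow.Tensor_Scalar + mean.dim
    x = x * torch.rsqrt(var + eps)             # add.Scalar + rsqrt + mul.Tensor
    return (weight_f32 * x.half()).half()      # f32*f16 promotes to f32, cast back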
!torch.vtensor<[4096,4096],f16> - %int4_1260 = torch.constant.int 4 - %1294 = torch.aten.mul.int %int4_1260, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1261 = torch.constant.int 4096 - %1295 = torch.prim.ListConstruct %1294, %int4096_1261 : (!torch.int, !torch.int) -> !torch.list - %1296 = torch.aten.view %1292, %1295 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1296, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1297 = torch.aten.mm %1296, %1293 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1297, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_1262 = torch.constant.int 4 - %int4096_1263 = torch.constant.int 4096 - %1298 = torch.prim.ListConstruct %int4_1262, %306, %int4096_1263 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1299 = torch.aten.view %1297, %1298 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1299, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_1264 = torch.constant.int -2 - %int-1_1265 = torch.constant.int -1 - %1300 = torch.aten.transpose.int %48, %int-2_1264, %int-1_1265 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_1266 = torch.constant.int 4 - %1301 = torch.aten.mul.int %int4_1266, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1267 = torch.constant.int 4096 - %1302 = torch.prim.ListConstruct %1301, %int4096_1267 : (!torch.int, !torch.int) -> !torch.list - %1303 = torch.aten.view %1292, %1302 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1303, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1304 = torch.aten.mm %1303, %1300 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %1304, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_1268 = torch.constant.int 4 - %int1024_1269 = torch.constant.int 1024 - %1305 = torch.prim.ListConstruct %int4_1268, %306, %int1024_1269 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1306 = torch.aten.view %1304, %1305 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %1306, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_1270 = torch.constant.int -2 - %int-1_1271 = torch.constant.int -1 - %1307 = torch.aten.transpose.int %49, %int-2_1270, %int-1_1271 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_1272 = torch.constant.int 4 - %1308 = torch.aten.mul.int %int4_1272, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1273 = torch.constant.int 4096 - %1309 = torch.prim.ListConstruct %1308, %int4096_1273 : (!torch.int, !torch.int) -> !torch.list - %1310 = torch.aten.view %1292, %1309 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1310, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1311 = torch.aten.mm %1310, %1307 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %1311, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : 
!torch.vtensor<[?,1024],f16> - %int4_1274 = torch.constant.int 4 - %int1024_1275 = torch.constant.int 1024 - %1312 = torch.prim.ListConstruct %int4_1274, %306, %int1024_1275 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1313 = torch.aten.view %1311, %1312 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %1313, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_1276 = torch.constant.int 4 - %int32_1277 = torch.constant.int 32 - %int128_1278 = torch.constant.int 128 - %1314 = torch.prim.ListConstruct %int4_1276, %306, %int32_1277, %int128_1278 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1315 = torch.aten.view %1299, %1314 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1315, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_1279 = torch.constant.int 4 - %int8_1280 = torch.constant.int 8 - %int128_1281 = torch.constant.int 128 - %1316 = torch.prim.ListConstruct %int4_1279, %306, %int8_1280, %int128_1281 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1317 = torch.aten.view %1306, %1316 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1317, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_1282 = torch.constant.int 4 - %int8_1283 = torch.constant.int 8 - %int128_1284 = torch.constant.int 128 - %1318 = torch.prim.ListConstruct %int4_1282, %306, %int8_1283, %int128_1284 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1319 = torch.aten.view %1313, %1318 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1319, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_1285 = torch.constant.int 131072 - %none_1286 = torch.constant.none - %none_1287 = torch.constant.none - %cpu_1288 = torch.constant.device "cpu" - %false_1289 = torch.constant.bool false - %1320 = torch.aten.arange %int131072_1285, %none_1286, %none_1287, %cpu_1288, %false_1289 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_1290 = torch.constant.int 0 - %int128_1291 = torch.constant.int 128 - %none_1292 = torch.constant.none - %none_1293 = torch.constant.none - %cpu_1294 = torch.constant.device "cpu" - %false_1295 = torch.constant.bool false - %1321 = torch.aten.arange.start %int0_1290, %int128_1291, %none_1292, %none_1293, %cpu_1294, %false_1295 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_1296 = torch.constant.int 2 - %1322 = torch.aten.floor_divide.Scalar %1321, %int2_1296 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_1297 = torch.constant.int 6 - %1323 = torch.prims.convert_element_type %1322, %int6_1297 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_1298 = torch.constant.int 128 - %1324 = torch.aten.div.Scalar %1323, %int128_1298 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_1299 = torch.constant.float 2.000000e+00 - %1325 = torch.aten.mul.Scalar %1324, %float2.000000e00_1299 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_1300 = torch.constant.float 5.000000e+05 - %1326 = 
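[note] The projection pattern deleted here (and re-emitted with new SSA numbers elsewhere) is how every linear layer in this IR is decomposed: transpose the [out, in] weight, flatten batch and sequence into one dimension, aten.mm, then view back. Sketch:

import torch

def linear_as_mm(x, w):
    # x: [4, s, 4096]; w: [out_features, 4096], row-major as in the globals
    wT = w.transpose(-2, -1)              # aten.transpose.int with dims -2, -1
    flat = x.reshape(-1, x.shape[-1])     # the view to [4*s, 4096]
    y = flat @ wT                         # aten.mm
    return y.reshape(*x.shape[:-1], w.shape[0])  # view back to [4, s, out]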
torch.aten.pow.Scalar %float5.000000e05_1300, %1325 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %1327 = torch.aten.reciprocal %1326 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_1301 = torch.constant.float 1.000000e+00 - %1328 = torch.aten.mul.Scalar %1327, %float1.000000e00_1301 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_1302 = torch.constant.int 1 - %1329 = torch.aten.unsqueeze %1320, %int1_1302 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_1303 = torch.constant.int 0 - %1330 = torch.aten.unsqueeze %1328, %int0_1303 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %1331 = torch.aten.mul.Tensor %1329, %1330 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_1304 = torch.constant.int 1 - %1332 = torch.aten.size.int %1299, %int1_1304 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_1305 = torch.constant.int 0 - %1333 = torch.aten.add.int %int0_1305, %1332 : !torch.int, !torch.int -> !torch.int - %int0_1306 = torch.constant.int 0 - %int0_1307 = torch.constant.int 0 - %int1_1308 = torch.constant.int 1 - %1334 = torch.aten.slice.Tensor %1331, %int0_1306, %int0_1307, %1333, %int1_1308 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1334, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1309 = torch.constant.int 1 - %int0_1310 = torch.constant.int 0 - %int9223372036854775807_1311 = torch.constant.int 9223372036854775807 - %int1_1312 = torch.constant.int 1 - %1335 = torch.aten.slice.Tensor %1334, %int1_1309, %int0_1310, %int9223372036854775807_1311, %int1_1312 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1335, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1313 = torch.constant.int 1 - %int0_1314 = torch.constant.int 0 - %int9223372036854775807_1315 = torch.constant.int 9223372036854775807 - %int1_1316 = torch.constant.int 1 - %1336 = torch.aten.slice.Tensor %1335, %int1_1313, %int0_1314, %int9223372036854775807_1315, %int1_1316 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1336, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_1317 = torch.constant.int 0 - %1337 = torch.aten.unsqueeze %1336, %int0_1317 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1337, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_1318 = torch.constant.int 1 - %int0_1319 = torch.constant.int 0 - %int9223372036854775807_1320 = torch.constant.int 9223372036854775807 - %int1_1321 = torch.constant.int 1 - %1338 = torch.aten.slice.Tensor %1337, %int1_1318, %int0_1319, %int9223372036854775807_1320, %int1_1321 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1338, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_1322 = torch.constant.int 2 - %int0_1323 = torch.constant.int 0 - %int9223372036854775807_1324 = torch.constant.int 9223372036854775807 + %1284 = torch.aten.slice.Tensor %1283, %int3_1252, 
%int0_1253, %int9223372036854775807_1254, %int1_1255 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1284, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_1256 = torch.constant.int 4 + %int1_1257 = torch.constant.int 1 + %int1_1258 = torch.constant.int 1 + %int1_1259 = torch.constant.int 1 + %1285 = torch.prim.ListConstruct %int4_1256, %int1_1257, %int1_1258, %int1_1259 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1286 = torch.aten.repeat %1284, %1285 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %1286, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_1260 = torch.constant.int 0 + %1287 = torch.aten.unsqueeze %1280, %int0_1260 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1287, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_1261 = torch.constant.int 1 + %int0_1262 = torch.constant.int 0 + %int9223372036854775807_1263 = torch.constant.int 9223372036854775807 + %int1_1264 = torch.constant.int 1 + %1288 = torch.aten.slice.Tensor %1287, %int1_1261, %int0_1262, %int9223372036854775807_1263, %int1_1264 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1288, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_1265 = torch.constant.int 2 + %1289 = torch.aten.unsqueeze %1288, %int2_1265 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1289, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_1266 = torch.constant.int 3 + %int0_1267 = torch.constant.int 0 + %int9223372036854775807_1268 = torch.constant.int 9223372036854775807 + %int1_1269 = torch.constant.int 1 + %1290 = torch.aten.slice.Tensor %1289, %int3_1266, %int0_1267, %int9223372036854775807_1268, %int1_1269 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1290, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_1270 = torch.constant.int 4 + %int1_1271 = torch.constant.int 1 + %int1_1272 = torch.constant.int 1 + %int1_1273 = torch.constant.int 1 + %1291 = torch.prim.ListConstruct %int4_1270, %int1_1271, %int1_1272, %int1_1273 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1292 = torch.aten.repeat %1290, %1291 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %1292, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %1293 = torch.aten.mul.Tensor %1172, %1286 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1293, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_1274 = torch.constant.int 3 + %int0_1275 = torch.constant.int 0 + %int64_1276 = torch.constant.int 64 + %int1_1277 = torch.constant.int 1 + %1294 = torch.aten.slice.Tensor %1172, %int3_1274, %int0_1275, %int64_1276, %int1_1277 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> 
!torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %1294, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_1278 = torch.constant.int 3 + %int64_1279 = torch.constant.int 64 + %int9223372036854775807_1280 = torch.constant.int 9223372036854775807 + %int1_1281 = torch.constant.int 1 + %1295 = torch.aten.slice.Tensor %1172, %int3_1278, %int64_1279, %int9223372036854775807_1280, %int1_1281 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %1295, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %1296 = torch.aten.neg %1295 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %1296, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %1297 = torch.prim.ListConstruct %1296, %1294 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_1282 = torch.constant.int -1 + %1298 = torch.aten.cat %1297, %int-1_1282 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1298, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %1299 = torch.aten.mul.Tensor %1298, %1292 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1299, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_1283 = torch.constant.int 1 + %1300 = torch.aten.add.Tensor %1293, %1299, %int1_1283 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1300, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_1284 = torch.constant.int 32 + %1301 = torch.aten.mul.Scalar %arg2, %int32_1284 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1301, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int3_1285 = torch.constant.int 3 + %int1_1286 = torch.constant.int 1 + %1302 = torch.aten.add.Scalar %1301, %int3_1285, %int1_1286 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1302, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_1287 = torch.constant.int 2 + %1303 = torch.aten.mul.Scalar %1302, %int2_1287 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1303, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_1288 = torch.constant.int 0 + %int1_1289 = torch.constant.int 1 + %1304 = torch.aten.add.Scalar %1303, %int0_1288, %int1_1289 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1304, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %1305 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %1306 = torch.aten.view %1304, %1305 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %1306, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_1290 = torch.constant.int 4 + %int32_1291 = torch.constant.int 32 + %int8_1292 = torch.constant.int 8 + %int128_1293 = torch.constant.int 128 + %1307 = torch.prim.ListConstruct %int4_1290, %296, %int32_1291, %int8_1292, %int128_1293 
: (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1308 = torch.aten.view %1300, %1307 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1308, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_1294 = torch.constant.int 32 + %int8_1295 = torch.constant.int 8 + %int128_1296 = torch.constant.int 128 + %1309 = torch.prim.ListConstruct %504, %int32_1294, %int8_1295, %int128_1296 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1310 = torch.aten.view %1308, %1309 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %1310, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_1297 = torch.constant.int 1 + %int2_1298 = torch.constant.int 2 + %1311 = torch.aten.transpose.int %1310, %int1_1297, %int2_1298 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %1311, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_1299 = torch.constant.int 5 + %1312 = torch.prims.convert_element_type %1311, %int5_1299 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %1312, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_1300 = torch.constant.int 32 + %int2_1301 = torch.constant.int 2 + %int8_1302 = torch.constant.int 8 + %int32_1303 = torch.constant.int 32 + %int128_1304 = torch.constant.int 128 + %1313 = torch.prim.ListConstruct %297, %int32_1300, %int2_1301, %int8_1302, %int32_1303, %int128_1304 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1314 = torch.aten.view %1076, %1313 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1314, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_1305 = torch.constant.int 8 + %int32_1306 = torch.constant.int 32 + %int128_1307 = torch.constant.int 128 + %1315 = torch.prim.ListConstruct %497, %int8_1305, %int32_1306, %int128_1307 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1316 = torch.aten.view %1314, %1315 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %1316, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %1317 = torch.prim.ListConstruct %1306 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_1308 = torch.constant.bool false + %1318 = torch.aten.index_put %1316, %1317, %1312, %false_1308 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %1318, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_1309 = torch.constant.int 32 + %int2_1310 = torch.constant.int 2 + %int8_1311 = torch.constant.int 8 + %int32_1312 = torch.constant.int 32 + %int128_1313 = torch.constant.int 128 + %1319 = torch.prim.ListConstruct %297, %int32_1309, %int2_1310, %int8_1311, %int32_1312, %int128_1313 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1320 = torch.aten.view %1318, %1319 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> 
!torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1320, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_1314 = torch.constant.int 2097152 + %1321 = torch.prim.ListConstruct %297, %int2097152_1314 : (!torch.int, !torch.int) -> !torch.list + %1322 = torch.aten.view %1320, %1321 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %1322, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_1315 = torch.constant.int 32 + %int2_1316 = torch.constant.int 2 + %int8_1317 = torch.constant.int 8 + %int32_1318 = torch.constant.int 32 + %int128_1319 = torch.constant.int 128 + %1323 = torch.prim.ListConstruct %297, %int32_1315, %int2_1316, %int8_1317, %int32_1318, %int128_1319 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1324 = torch.aten.view %1322, %1323 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1324, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_1320 = torch.constant.int 8 + %int32_1321 = torch.constant.int 32 + %int128_1322 = torch.constant.int 128 + %1325 = torch.prim.ListConstruct %497, %int8_1320, %int32_1321, %int128_1322 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1326 = torch.aten.view %1324, %1325 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %1326, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_1323 = torch.constant.int 32 + %1327 = torch.aten.mul.Scalar %arg2, %int32_1323 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1327, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int3_1324 = torch.constant.int 3 %int1_1325 = torch.constant.int 1 - %1339 = torch.aten.slice.Tensor %1338, %int2_1322, %int0_1323, %int9223372036854775807_1324, %int1_1325 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1339, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_1326 = torch.constant.int 4 + %1328 = torch.aten.add.Scalar %1327, %int3_1324, %int1_1325 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1328, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_1326 = torch.constant.int 2 + %1329 = torch.aten.mul.Scalar %1328, %int2_1326 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1329, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> %int1_1327 = torch.constant.int 1 %int1_1328 = torch.constant.int 1 - %1340 = torch.prim.ListConstruct %int4_1326, %int1_1327, %int1_1328 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1341 = torch.aten.repeat %1339, %1340 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %1341, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_1329 = torch.constant.int 6 - %1342 = torch.prims.convert_element_type %1315, %int6_1329 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %1342, 
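[note] %1301 through %1326 are the paged-KV-cache write for this block's K (the V write follows, differing only in the final +1 on the slot id). The [?, 2097152] buffer %arg3 is viewed as [pages, 32, 2, 8, 32, 128] (plausibly pages, layers, K/V, kv-heads, page-slot, head-dim; the names are my assumption), flattened to [?, 8, 32, 128], and index_put scatters one [8, 32, 128] tile per 32-token page at slot (page_id*32 + layer)*2 + 0, with layer index 3 in this hunk. An in-place PyTorch approximation (the IR's index_put is out-of-place but computes the same result):

import torch

def write_kv_block(cache, page_ids, k_block, layer=3, is_value=False):
    # cache: [num_pages, 2097152] f16, where 2097152 = 32 * 2 * 8 * 32 * 128
    paged = cache.view(-1, 32, 2, 8, 32, 128).reshape(-1, 8, 32, 128)
    # (page*32 + layer)*2 + {0: K, 1: V} -- the mul/add.Scalar chain on %arg2
    slots = ((page_ids * 32 + layer) * 2 + int(is_value)).reshape(-1)
    # k_block: [4, s, 8, 128] -> one [8, 32, 128] tile per 32-token page
    tiles = k_block.view(4, -1, 32, 8, 128).reshape(-1, 32, 8, 128).transpose(1, 2)
    paged.index_put_((slots,), tiles)      # aten.index_put, accumulate=False
    return paged.reshape(cache.shape)      # the views back to [num_pages, 2097152]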
[%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %1343 = torch_c.to_builtin_tensor %1342 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %1344 = torch_c.to_builtin_tensor %1341 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %1345 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%1343, %1344) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %1346 = torch_c.from_builtin_tensor %1345 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %1346, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_1330 = torch.constant.int 5 - %1347 = torch.prims.convert_element_type %1346, %int5_1330 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1347, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_1331 = torch.constant.int 131072 - %none_1332 = torch.constant.none - %none_1333 = torch.constant.none - %cpu_1334 = torch.constant.device "cpu" - %false_1335 = torch.constant.bool false - %1348 = torch.aten.arange %int131072_1331, %none_1332, %none_1333, %cpu_1334, %false_1335 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_1336 = torch.constant.int 0 - %int128_1337 = torch.constant.int 128 - %none_1338 = torch.constant.none - %none_1339 = torch.constant.none - %cpu_1340 = torch.constant.device "cpu" - %false_1341 = torch.constant.bool false - %1349 = torch.aten.arange.start %int0_1336, %int128_1337, %none_1338, %none_1339, %cpu_1340, %false_1341 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_1342 = torch.constant.int 2 - %1350 = torch.aten.floor_divide.Scalar %1349, %int2_1342 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_1343 = torch.constant.int 6 - %1351 = torch.prims.convert_element_type %1350, %int6_1343 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> + %1330 = torch.aten.add.Scalar %1329, %int1_1327, %int1_1328 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1330, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %1331 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %1332 = torch.aten.view %1330, %1331 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %1332, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_1329 = torch.constant.int 4 + %int32_1330 = torch.constant.int 32 + %int8_1331 = torch.constant.int 8 + %int128_1332 = torch.constant.int 128 + %1333 = torch.prim.ListConstruct %int4_1329, %296, %int32_1330, %int8_1331, %int128_1332 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1334 = torch.aten.view %1174, %1333 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1334, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_1333 = torch.constant.int 32 + %int8_1334 = torch.constant.int 8 + %int128_1335 = torch.constant.int 128 + %1335 = torch.prim.ListConstruct %504, %int32_1333, %int8_1334, %int128_1335 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1336 = torch.aten.view %1334, %1335 : 
!torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %1336, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_1336 = torch.constant.int 1 + %int2_1337 = torch.constant.int 2 + %1337 = torch.aten.transpose.int %1336, %int1_1336, %int2_1337 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %1337, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_1338 = torch.constant.int 5 + %1338 = torch.prims.convert_element_type %1337, %int5_1338 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %1338, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %1339 = torch.prim.ListConstruct %1332 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_1339 = torch.constant.bool false + %1340 = torch.aten.index_put %1326, %1339, %1338, %false_1339 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %1340, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_1340 = torch.constant.int 32 + %int2_1341 = torch.constant.int 2 + %int8_1342 = torch.constant.int 8 + %int32_1343 = torch.constant.int 32 %int128_1344 = torch.constant.int 128 - %1352 = torch.aten.div.Scalar %1351, %int128_1344 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_1345 = torch.constant.float 2.000000e+00 - %1353 = torch.aten.mul.Scalar %1352, %float2.000000e00_1345 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_1346 = torch.constant.float 5.000000e+05 - %1354 = torch.aten.pow.Scalar %float5.000000e05_1346, %1353 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %1355 = torch.aten.reciprocal %1354 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_1347 = torch.constant.float 1.000000e+00 - %1356 = torch.aten.mul.Scalar %1355, %float1.000000e00_1347 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_1348 = torch.constant.int 1 - %1357 = torch.aten.unsqueeze %1348, %int1_1348 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_1349 = torch.constant.int 0 - %1358 = torch.aten.unsqueeze %1356, %int0_1349 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %1359 = torch.aten.mul.Tensor %1357, %1358 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_1350 = torch.constant.int 1 - %1360 = torch.aten.size.int %1306, %int1_1350 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_1351 = torch.constant.int 0 - %1361 = torch.aten.add.int %int0_1351, %1360 : !torch.int, !torch.int -> !torch.int + %1341 = torch.prim.ListConstruct %297, %int32_1340, %int2_1341, %int8_1342, %int32_1343, %int128_1344 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1342 = torch.aten.view %1340, %1341 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1342, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_1345 = torch.constant.int 2097152 + %1343 = torch.prim.ListConstruct 
%297, %int2097152_1345 : (!torch.int, !torch.int) -> !torch.list + %1344 = torch.aten.view %1342, %1343 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %1344, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_1346 = torch.constant.int -2 + %1345 = torch.aten.unsqueeze %1300, %int-2_1346 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %1345, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_1347 = torch.constant.int 4 + %int8_1348 = torch.constant.int 8 + %int4_1349 = torch.constant.int 4 + %int128_1350 = torch.constant.int 128 + %1346 = torch.prim.ListConstruct %int4_1347, %298, %int8_1348, %int4_1349, %int128_1350 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_1351 = torch.constant.bool false + %1347 = torch.aten.expand %1345, %1346, %false_1351 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1347, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int0_1352 = torch.constant.int 0 - %int0_1353 = torch.constant.int 0 - %int1_1354 = torch.constant.int 1 - %1362 = torch.aten.slice.Tensor %1359, %int0_1352, %int0_1353, %1361, %int1_1354 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1362, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1355 = torch.constant.int 1 - %int0_1356 = torch.constant.int 0 - %int9223372036854775807_1357 = torch.constant.int 9223372036854775807 - %int1_1358 = torch.constant.int 1 - %1363 = torch.aten.slice.Tensor %1362, %int1_1355, %int0_1356, %int9223372036854775807_1357, %int1_1358 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1363, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1359 = torch.constant.int 1 - %int0_1360 = torch.constant.int 0 - %int9223372036854775807_1361 = torch.constant.int 9223372036854775807 - %int1_1362 = torch.constant.int 1 - %1364 = torch.aten.slice.Tensor %1363, %int1_1359, %int0_1360, %int9223372036854775807_1361, %int1_1362 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1364, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_1363 = torch.constant.int 0 - %1365 = torch.aten.unsqueeze %1364, %int0_1363 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1365, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_1364 = torch.constant.int 1 - %int0_1365 = torch.constant.int 0 - %int9223372036854775807_1366 = torch.constant.int 9223372036854775807 - %int1_1367 = torch.constant.int 1 - %1366 = torch.aten.slice.Tensor %1365, %int1_1364, %int0_1365, %int9223372036854775807_1366, %int1_1367 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1366, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_1368 = torch.constant.int 2 - %int0_1369 = torch.constant.int 0 - %int9223372036854775807_1370 = 
torch.constant.int 9223372036854775807 - %int1_1371 = torch.constant.int 1 - %1367 = torch.aten.slice.Tensor %1366, %int2_1368, %int0_1369, %int9223372036854775807_1370, %int1_1371 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1367, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_1372 = torch.constant.int 4 - %int1_1373 = torch.constant.int 1 - %int1_1374 = torch.constant.int 1 - %1368 = torch.prim.ListConstruct %int4_1372, %int1_1373, %int1_1374 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1369 = torch.aten.repeat %1367, %1368 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %1369, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_1375 = torch.constant.int 6 - %1370 = torch.prims.convert_element_type %1317, %int6_1375 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %1370, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %1371 = torch_c.to_builtin_tensor %1370 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %1372 = torch_c.to_builtin_tensor %1369 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %1373 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%1371, %1372) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %1374 = torch_c.from_builtin_tensor %1373 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %1374, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_1376 = torch.constant.int 5 - %1375 = torch.prims.convert_element_type %1374, %int5_1376 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1375, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_1377 = torch.constant.int 64 - %1376 = torch.aten.mul.Scalar %arg2, %int64_1377 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1376, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int10 = torch.constant.int 10 - %int1_1378 = torch.constant.int 1 - %1377 = torch.aten.add.Scalar %1376, %int10, %int1_1378 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1377, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_1379 = torch.constant.int 4 - %int32_1380 = torch.constant.int 32 - %int8_1381 = torch.constant.int 8 - %int128_1382 = torch.constant.int 128 - %1378 = torch.prim.ListConstruct %int4_1379, %398, %int32_1380, %int8_1381, %int128_1382 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1379 = torch.aten.view %1375, %1378 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1379, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %1348 = torch.aten.clone %1347, %int0_1352 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1348, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_1353 = torch.constant.int 4 + %int32_1354 = torch.constant.int 32 + %int128_1355 = torch.constant.int 128 + 
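// annotation: the index_put sequences above scatter the current K/V slices into the flat
// paged cache (%1076 on the `+` side, %1219 on the `-` side), with [?,2097152] viewed as
// pages x 32 x 2 x 8 x 32 x 128, plausibly layers x K/V x kv-heads x positions-per-page x
// head-dim; note the `-` side kept a positions-before-heads layout ([?,32,2,32,8,128])
// while the `+` side stores heads first ([?,32,2,8,32,128]). The page-table entries in
// %arg2 become flat row indices via what looks like (page * 32 + block index) * 2, plus 1
// for the V slot. The surrounding `+` ops then perform the grouped-query-attention
// expansion: the 8 KV heads ([4,?,8,128]) are unsqueezed, expanded 4x, and flattened with
// _unsafe_view to the 32 query heads ([4,?,32,128]) before the
// _scaled_dot_product_flash_attention_for_cpu call, which now takes the explicit
// attention mask %327 where the removed variant passed is_causal=true.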
%1349 = torch.prim.ListConstruct %int4_1353, %298, %int32_1354, %int128_1355 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1350 = torch.aten._unsafe_view %1348, %1349 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1350, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_1356 = torch.constant.int -2 + %1351 = torch.aten.unsqueeze %1174, %int-2_1356 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %1351, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_1357 = torch.constant.int 4 + %int8_1358 = torch.constant.int 8 + %int4_1359 = torch.constant.int 4 + %int128_1360 = torch.constant.int 128 + %1352 = torch.prim.ListConstruct %int4_1357, %298, %int8_1358, %int4_1359, %int128_1360 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_1361 = torch.constant.bool false + %1353 = torch.aten.expand %1351, %1352, %false_1361 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1353, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_1362 = torch.constant.int 0 + %1354 = torch.aten.clone %1353, %int0_1362 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1354, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_1363 = torch.constant.int 4 + %int32_1364 = torch.constant.int 32 + %int128_1365 = torch.constant.int 128 + %1355 = torch.prim.ListConstruct %int4_1363, %298, %int32_1364, %int128_1365 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1356 = torch.aten._unsafe_view %1354, %1355 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1356, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_1366 = torch.constant.int 1 + %int2_1367 = torch.constant.int 2 + %1357 = torch.aten.transpose.int %1237, %int1_1366, %int2_1367 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1357, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_1368 = torch.constant.int 1 + %int2_1369 = torch.constant.int 2 + %1358 = torch.aten.transpose.int %1350, %int1_1368, %int2_1369 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1358, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_1370 = torch.constant.int 1 + %int2_1371 = torch.constant.int 2 + %1359 = torch.aten.transpose.int %1356, %int1_1370, %int2_1371 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1359, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_1372 = torch.constant.float 0.000000e+00 + %false_1373 = torch.constant.bool false + %none_1374 = torch.constant.none + %1360:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1357, %1358, %1359, %float0.000000e00_1372, %false_1373, %327, %none_1374) : (!torch.vtensor<[4,32,?,128],f16>, 
!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %1360#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_1375 = torch.constant.int 1 + %int2_1376 = torch.constant.int 2 + %1361 = torch.aten.transpose.int %1360#0, %int1_1375, %int2_1376 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1361, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_1377 = torch.constant.int 4 + %int4096_1378 = torch.constant.int 4096 + %1362 = torch.prim.ListConstruct %int4_1377, %298, %int4096_1378 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1363 = torch.aten.view %1361, %1362 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1363, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_1379 = torch.constant.int -2 + %int-1_1380 = torch.constant.int -1 + %1364 = torch.aten.transpose.int %33, %int-2_1379, %int-1_1380 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_1381 = torch.constant.int 5 + %1365 = torch.prims.convert_element_type %1364, %int5_1381 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_1382 = torch.constant.int 4096 + %1366 = torch.prim.ListConstruct %342, %int4096_1382 : (!torch.int, !torch.int) -> !torch.list + %1367 = torch.aten.view %1363, %1366 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1367, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1368 = torch.aten.mm %1367, %1365 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1368, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> %int4_1383 = torch.constant.int 4 - %1380 = torch.aten.mul.int %int4_1383, %398 : !torch.int, !torch.int -> !torch.int - %int32_1384 = torch.constant.int 32 - %int8_1385 = torch.constant.int 8 - %int128_1386 = torch.constant.int 128 - %1381 = torch.prim.ListConstruct %1380, %int32_1384, %int8_1385, %int128_1386 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1382 = torch.aten.view %1379, %1381 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1382, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_1387 = torch.constant.int 4 - %1383 = torch.aten.mul.int %int4_1387, %398 : !torch.int, !torch.int -> !torch.int - %1384 = torch.prim.ListConstruct %1383 : (!torch.int) -> !torch.list - %1385 = torch.aten.view %1377, %1384 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %1385, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_1388 = torch.constant.int 32 - %int2_1389 = torch.constant.int 2 - %int32_1390 = torch.constant.int 32 - %int8_1391 = torch.constant.int 8 - %int128_1392 = torch.constant.int 128 - %1386 = torch.prim.ListConstruct %389, %int32_1388, %int2_1389, %int32_1390, %int8_1391, %int128_1392 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> 
!torch.list - %1387 = torch.aten.view %1219, %1386 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1387, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_1393 = torch.constant.int 32 - %1388 = torch.aten.mul.int %389, %int32_1393 : !torch.int, !torch.int -> !torch.int - %int2_1394 = torch.constant.int 2 - %1389 = torch.aten.mul.int %1388, %int2_1394 : !torch.int, !torch.int -> !torch.int - %int32_1395 = torch.constant.int 32 - %int8_1396 = torch.constant.int 8 - %int128_1397 = torch.constant.int 128 - %1390 = torch.prim.ListConstruct %1389, %int32_1395, %int8_1396, %int128_1397 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1391 = torch.aten.view %1387, %1390 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1391, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %1392 = torch.prim.ListConstruct %1385 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_1398 = torch.constant.bool false - %1393 = torch.aten.index_put %1391, %1392, %1382, %false_1398 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1393, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_1399 = torch.constant.int 32 - %int2_1400 = torch.constant.int 2 - %int32_1401 = torch.constant.int 32 - %int8_1402 = torch.constant.int 8 - %int128_1403 = torch.constant.int 128 - %1394 = torch.prim.ListConstruct %389, %int32_1399, %int2_1400, %int32_1401, %int8_1402, %int128_1403 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1395 = torch.aten.view %1393, %1394 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1395, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_1404 = torch.constant.int 2097152 - %1396 = torch.prim.ListConstruct %389, %int2097152_1404 : (!torch.int, !torch.int) -> !torch.list - %1397 = torch.aten.view %1395, %1396 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1397, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_1405 = torch.constant.int 32 - %int2_1406 = torch.constant.int 2 - %int32_1407 = torch.constant.int 32 - %int8_1408 = torch.constant.int 8 - %int128_1409 = torch.constant.int 128 - %1398 = torch.prim.ListConstruct %389, %int32_1405, %int2_1406, %int32_1407, %int8_1408, %int128_1409 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1399 = torch.aten.view %1397, %1398 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1399, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_1410 = torch.constant.int 32 - %int8_1411 = torch.constant.int 8 - %int128_1412 = torch.constant.int 128 - %1400 = torch.prim.ListConstruct %1389, %int32_1410, %int8_1411, %int128_1412 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1401 = torch.aten.view %1399, %1400 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape 
%1401, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_1413 = torch.constant.int 4 - %int32_1414 = torch.constant.int 32 - %int8_1415 = torch.constant.int 8 - %int128_1416 = torch.constant.int 128 - %1402 = torch.prim.ListConstruct %int4_1413, %398, %int32_1414, %int8_1415, %int128_1416 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1403 = torch.aten.view %1319, %1402 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1403, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_1417 = torch.constant.int 4 - %1404 = torch.aten.mul.int %int4_1417, %398 : !torch.int, !torch.int -> !torch.int - %int32_1418 = torch.constant.int 32 - %int8_1419 = torch.constant.int 8 - %int128_1420 = torch.constant.int 128 - %1405 = torch.prim.ListConstruct %1404, %int32_1418, %int8_1419, %int128_1420 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1406 = torch.aten.view %1403, %1405 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1406, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_1421 = torch.constant.int 1 - %int1_1422 = torch.constant.int 1 - %1407 = torch.aten.add.Scalar %1377, %int1_1421, %int1_1422 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1407, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_1423 = torch.constant.int 4 - %1408 = torch.aten.mul.int %int4_1423, %398 : !torch.int, !torch.int -> !torch.int - %1409 = torch.prim.ListConstruct %1408 : (!torch.int) -> !torch.list - %1410 = torch.aten.view %1407, %1409 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %1410, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %1411 = torch.prim.ListConstruct %1410 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_1424 = torch.constant.bool false - %1412 = torch.aten.index_put %1401, %1411, %1406, %false_1424 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1412, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_1425 = torch.constant.int 32 - %int2_1426 = torch.constant.int 2 - %int32_1427 = torch.constant.int 32 - %int8_1428 = torch.constant.int 8 - %int128_1429 = torch.constant.int 128 - %1413 = torch.prim.ListConstruct %389, %int32_1425, %int2_1426, %int32_1427, %int8_1428, %int128_1429 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1414 = torch.aten.view %1412, %1413 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1414, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_1430 = torch.constant.int 2097152 - %1415 = torch.prim.ListConstruct %389, %int2097152_1430 : (!torch.int, !torch.int) -> !torch.list - %1416 = torch.aten.view %1414, %1415 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1416, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_1431 = torch.constant.int -2 - %1417 = torch.aten.unsqueeze 
%1375, %int-2_1431 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1417, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_1432 = torch.constant.int 4 - %int8_1433 = torch.constant.int 8 - %int4_1434 = torch.constant.int 4 - %int128_1435 = torch.constant.int 128 - %1418 = torch.prim.ListConstruct %int4_1432, %1360, %int8_1433, %int4_1434, %int128_1435 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1436 = torch.constant.bool false - %1419 = torch.aten.expand %1417, %1418, %false_1436 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1419, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1437 = torch.constant.int 0 - %1420 = torch.aten.clone %1419, %int0_1437 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1420, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_1438 = torch.constant.int 4 - %int32_1439 = torch.constant.int 32 - %int128_1440 = torch.constant.int 128 - %1421 = torch.prim.ListConstruct %int4_1438, %1360, %int32_1439, %int128_1440 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1422 = torch.aten._unsafe_view %1420, %1421 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1422, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_1441 = torch.constant.int -2 - %1423 = torch.aten.unsqueeze %1319, %int-2_1441 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1423, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_1442 = torch.constant.int 1 - %1424 = torch.aten.size.int %1313, %int1_1442 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_1443 = torch.constant.int 4 - %int8_1444 = torch.constant.int 8 - %int4_1445 = torch.constant.int 4 + %int4096_1384 = torch.constant.int 4096 + %1369 = torch.prim.ListConstruct %int4_1383, %298, %int4096_1384 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1370 = torch.aten.view %1368, %1369 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1370, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_1385 = torch.constant.int 1 + %1371 = torch.aten.add.Tensor %1137, %1370, %int1_1385 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1371, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_1386 = torch.constant.int 6 + %1372 = torch.prims.convert_element_type %1371, %int6_1386 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1372, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_1387 = torch.constant.int 2 + %1373 = torch.aten.pow.Tensor_Scalar %1372, %int2_1387 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1373, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_1388 = torch.constant.int -1 
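// annotation: together with the squaring just above, the mean / add-epsilon / rsqrt /
// multiply ops below form the usual RMSNorm pattern, x * rsqrt(mean(x^2, dim=-1) + 1e-5),
// applied to the post-attention residual %1371; activations are widened to f32 (dtype
// code 6) for the statistics and narrowed back to f16 (dtype code 5) before scaling by
// the ffn_norm weight. The transpose/mm/silu/mul/mm chain that follows is presumably the
// SwiGLU feed-forward block: gate and up projections to 14336, silu(gate) * up, a down
// projection back to 4096, and another residual add.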
+ %1374 = torch.prim.ListConstruct %int-1_1388 : (!torch.int) -> !torch.list + %true_1389 = torch.constant.bool true + %none_1390 = torch.constant.none + %1375 = torch.aten.mean.dim %1373, %1374, %true_1389, %none_1390 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1375, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_1391 = torch.constant.float 9.9999997473787516E-6 + %int1_1392 = torch.constant.int 1 + %1376 = torch.aten.add.Scalar %1375, %float9.999990e-06_1391, %int1_1392 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1376, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %1377 = torch.aten.rsqrt %1376 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1377, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %1378 = torch.aten.mul.Tensor %1372, %1377 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1378, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_1393 = torch.constant.int 5 + %1379 = torch.prims.convert_element_type %1378, %int5_1393 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1379, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %1380 = torch.aten.mul.Tensor %34, %1379 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1380, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_1394 = torch.constant.int 5 + %1381 = torch.prims.convert_element_type %1380, %int5_1394 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1381, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_1395 = torch.constant.int -2 + %int-1_1396 = torch.constant.int -1 + %1382 = torch.aten.transpose.int %35, %int-2_1395, %int-1_1396 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_1397 = torch.constant.int 5 + %1383 = torch.prims.convert_element_type %1382, %int5_1397 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_1398 = torch.constant.int 4096 + %1384 = torch.prim.ListConstruct %342, %int4096_1398 : (!torch.int, !torch.int) -> !torch.list + %1385 = torch.aten.view %1381, %1384 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1385, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1386 = torch.aten.mm %1385, %1383 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %1386, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_1399 = torch.constant.int 4 + %int14336_1400 = torch.constant.int 14336 + %1387 = torch.prim.ListConstruct %int4_1399, %298, %int14336_1400 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1388 = torch.aten.view %1386, %1387 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %1388, [%294], affine_map<()[s0] -> (4, s0 * 32, 
14336)> : !torch.vtensor<[4,?,14336],f16> + %1389 = torch.aten.silu %1388 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %1389, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_1401 = torch.constant.int -2 + %int-1_1402 = torch.constant.int -1 + %1390 = torch.aten.transpose.int %36, %int-2_1401, %int-1_1402 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_1403 = torch.constant.int 5 + %1391 = torch.prims.convert_element_type %1390, %int5_1403 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_1404 = torch.constant.int 4096 + %1392 = torch.prim.ListConstruct %342, %int4096_1404 : (!torch.int, !torch.int) -> !torch.list + %1393 = torch.aten.view %1381, %1392 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1393, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1394 = torch.aten.mm %1393, %1391 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %1394, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_1405 = torch.constant.int 4 + %int14336_1406 = torch.constant.int 14336 + %1395 = torch.prim.ListConstruct %int4_1405, %298, %int14336_1406 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1396 = torch.aten.view %1394, %1395 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %1396, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %1397 = torch.aten.mul.Tensor %1389, %1396 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %1397, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_1407 = torch.constant.int -2 + %int-1_1408 = torch.constant.int -1 + %1398 = torch.aten.transpose.int %37, %int-2_1407, %int-1_1408 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_1409 = torch.constant.int 5 + %1399 = torch.prims.convert_element_type %1398, %int5_1409 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_1410 = torch.constant.int 14336 + %1400 = torch.prim.ListConstruct %342, %int14336_1410 : (!torch.int, !torch.int) -> !torch.list + %1401 = torch.aten.view %1397, %1400 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %1401, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %1402 = torch.aten.mm %1401, %1399 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1402, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_1411 = torch.constant.int 4 + %int4096_1412 = torch.constant.int 4096 + %1403 = torch.prim.ListConstruct %int4_1411, %298, %int4096_1412 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1404 = torch.aten.view %1402, %1403 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1404, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_1413 = torch.constant.int 1 + %1405 = torch.aten.add.Tensor %1371, 
%1404, %int1_1413 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1405, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_1414 = torch.constant.int 6 + %1406 = torch.prims.convert_element_type %1405, %int6_1414 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1406, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_1415 = torch.constant.int 2 + %1407 = torch.aten.pow.Tensor_Scalar %1406, %int2_1415 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1407, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_1416 = torch.constant.int -1 + %1408 = torch.prim.ListConstruct %int-1_1416 : (!torch.int) -> !torch.list + %true_1417 = torch.constant.bool true + %none_1418 = torch.constant.none + %1409 = torch.aten.mean.dim %1407, %1408, %true_1417, %none_1418 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1409, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_1419 = torch.constant.float 9.9999997473787516E-6 + %int1_1420 = torch.constant.int 1 + %1410 = torch.aten.add.Scalar %1409, %float9.999990e-06_1419, %int1_1420 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1410, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %1411 = torch.aten.rsqrt %1410 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1411, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %1412 = torch.aten.mul.Tensor %1406, %1411 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1412, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_1421 = torch.constant.int 5 + %1413 = torch.prims.convert_element_type %1412, %int5_1421 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1413, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %1414 = torch.aten.mul.Tensor %38, %1413 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1414, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_1422 = torch.constant.int 5 + %1415 = torch.prims.convert_element_type %1414, %int5_1422 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1415, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_1423 = torch.constant.int -2 + %int-1_1424 = torch.constant.int -1 + %1416 = torch.aten.transpose.int %39, %int-2_1423, %int-1_1424 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_1425 = torch.constant.int 5 + %1417 = torch.prims.convert_element_type %1416, %int5_1425 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_1426 = torch.constant.int 4096 + %1418 = torch.prim.ListConstruct %342, %int4096_1426 : (!torch.int, !torch.int) -> !torch.list + %1419 = 
torch.aten.view %1415, %1418 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1419, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1420 = torch.aten.mm %1419, %1417 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1420, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_1427 = torch.constant.int 4 + %int4096_1428 = torch.constant.int 4096 + %1421 = torch.prim.ListConstruct %int4_1427, %298, %int4096_1428 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1422 = torch.aten.view %1420, %1421 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1422, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_1429 = torch.constant.int -2 + %int-1_1430 = torch.constant.int -1 + %1423 = torch.aten.transpose.int %40, %int-2_1429, %int-1_1430 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_1431 = torch.constant.int 5 + %1424 = torch.prims.convert_element_type %1423, %int5_1431 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_1432 = torch.constant.int 4096 + %1425 = torch.prim.ListConstruct %342, %int4096_1432 : (!torch.int, !torch.int) -> !torch.list + %1426 = torch.aten.view %1415, %1425 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1426, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1427 = torch.aten.mm %1426, %1424 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %1427, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_1433 = torch.constant.int 4 + %int1024_1434 = torch.constant.int 1024 + %1428 = torch.prim.ListConstruct %int4_1433, %298, %int1024_1434 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1429 = torch.aten.view %1427, %1428 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %1429, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_1435 = torch.constant.int -2 + %int-1_1436 = torch.constant.int -1 + %1430 = torch.aten.transpose.int %41, %int-2_1435, %int-1_1436 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_1437 = torch.constant.int 5 + %1431 = torch.prims.convert_element_type %1430, %int5_1437 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_1438 = torch.constant.int 4096 + %1432 = torch.prim.ListConstruct %342, %int4096_1438 : (!torch.int, !torch.int) -> !torch.list + %1433 = torch.aten.view %1415, %1432 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1433, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1434 = torch.aten.mm %1433, %1431 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %1434, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_1439 = torch.constant.int 4 + %int1024_1440 = torch.constant.int 1024 + %1435 = torch.prim.ListConstruct %int4_1439, %298, %int1024_1440 : (!torch.int, 
!torch.int, !torch.int) -> !torch.list + %1436 = torch.aten.view %1434, %1435 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %1436, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_1441 = torch.constant.int 4 + %int32_1442 = torch.constant.int 32 + %int128_1443 = torch.constant.int 128 + %1437 = torch.prim.ListConstruct %int4_1441, %298, %int32_1442, %int128_1443 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1438 = torch.aten.view %1422, %1437 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1438, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_1444 = torch.constant.int 4 + %int8_1445 = torch.constant.int 8 %int128_1446 = torch.constant.int 128 - %1425 = torch.prim.ListConstruct %int4_1443, %1424, %int8_1444, %int4_1445, %int128_1446 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1447 = torch.constant.bool false - %1426 = torch.aten.expand %1423, %1425, %false_1447 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1426, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1448 = torch.constant.int 0 - %1427 = torch.aten.clone %1426, %int0_1448 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1427, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_1449 = torch.constant.int 4 - %int32_1450 = torch.constant.int 32 - %int128_1451 = torch.constant.int 128 - %1428 = torch.prim.ListConstruct %int4_1449, %1424, %int32_1450, %int128_1451 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1429 = torch.aten._unsafe_view %1427, %1428 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1429, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_1452 = torch.constant.int 1 - %int2_1453 = torch.constant.int 2 - %1430 = torch.aten.transpose.int %1347, %int1_1452, %int2_1453 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1430, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_1454 = torch.constant.int 1 - %int2_1455 = torch.constant.int 2 - %1431 = torch.aten.transpose.int %1422, %int1_1454, %int2_1455 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1431, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_1456 = torch.constant.int 1 + %1439 = torch.prim.ListConstruct %int4_1444, %298, %int8_1445, %int128_1446 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1440 = torch.aten.view %1429, %1439 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1440, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_1447 = torch.constant.int 4 + %int8_1448 = torch.constant.int 8 + %int128_1449 = torch.constant.int 128 + %1441 = torch.prim.ListConstruct %int4_1447, %298, %int8_1448, %int128_1449 : (!torch.int, !torch.int, !torch.int, !torch.int) -> 
!torch.list + %1442 = torch.aten.view %1436, %1441 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1442, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_1450 = torch.constant.int 131072 + %none_1451 = torch.constant.none + %none_1452 = torch.constant.none + %cpu_1453 = torch.constant.device "cpu" + %false_1454 = torch.constant.bool false + %1443 = torch.aten.arange %int131072_1450, %none_1451, %none_1452, %cpu_1453, %false_1454 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_1455 = torch.constant.int 0 + %int128_1456 = torch.constant.int 128 %int2_1457 = torch.constant.int 2 - %1432 = torch.aten.transpose.int %1429, %int1_1456, %int2_1457 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1432, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_1458 = torch.constant.float 0.000000e+00 - %true_1459 = torch.constant.bool true - %none_1460 = torch.constant.none - %none_1461 = torch.constant.none - %1433:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1430, %1431, %1432, %float0.000000e00_1458, %true_1459, %none_1460, %none_1461) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %1433#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_1462 = torch.constant.int 1 - %int2_1463 = torch.constant.int 2 - %1434 = torch.aten.transpose.int %1433#0, %int1_1462, %int2_1463 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1434, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_1464 = torch.constant.int 4 - %int4096_1465 = torch.constant.int 4096 - %1435 = torch.prim.ListConstruct %int4_1464, %1332, %int4096_1465 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1436 = torch.aten.view %1434, %1435 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1436, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_1466 = torch.constant.int -2 - %int-1_1467 = torch.constant.int -1 - %1437 = torch.aten.transpose.int %50, %int-2_1466, %int-1_1467 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_1468 = torch.constant.int 4 - %1438 = torch.aten.mul.int %int4_1468, %1332 : !torch.int, !torch.int -> !torch.int - %int4096_1469 = torch.constant.int 4096 - %1439 = torch.prim.ListConstruct %1438, %int4096_1469 : (!torch.int, !torch.int) -> !torch.list - %1440 = torch.aten.view %1436, %1439 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1440, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1441 = torch.aten.mm %1440, %1437 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1441, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_1470 = torch.constant.int 4 - %int4096_1471 = 
torch.constant.int 4096 - %1442 = torch.prim.ListConstruct %int4_1470, %1332, %int4096_1471 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1443 = torch.aten.view %1441, %1442 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1443, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_1472 = torch.constant.int 1 - %1444 = torch.aten.add.Tensor %1282, %1443, %int1_1472 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1444, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_1473 = torch.constant.int 6 - %1445 = torch.prims.convert_element_type %1444, %int6_1473 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1445, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_1474 = torch.constant.int 2 - %1446 = torch.aten.pow.Tensor_Scalar %1445, %int2_1474 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1446, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_1475 = torch.constant.int -1 - %1447 = torch.prim.ListConstruct %int-1_1475 : (!torch.int) -> !torch.list - %true_1476 = torch.constant.bool true - %none_1477 = torch.constant.none - %1448 = torch.aten.mean.dim %1446, %1447, %true_1476, %none_1477 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1448, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_1478 = torch.constant.float 9.9999997473787516E-6 - %int1_1479 = torch.constant.int 1 - %1449 = torch.aten.add.Scalar %1448, %float9.999990e-06_1478, %int1_1479 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1449, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %1450 = torch.aten.rsqrt %1449 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1450, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %1451 = torch.aten.mul.Tensor %1445, %1450 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1451, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_1480 = torch.constant.int 5 - %1452 = torch.prims.convert_element_type %1451, %int5_1480 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1452, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %1453 = torch.aten.mul.Tensor %51, %1452 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1453, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_1481 = torch.constant.int 5 - %1454 = torch.prims.convert_element_type %1453, %int5_1481 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1454, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_1482 = torch.constant.int -2 - %int-1_1483 = torch.constant.int -1 - %1455 = torch.aten.transpose.int %52, 
%int-2_1482, %int-1_1483 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_1484 = torch.constant.int 4 - %1456 = torch.aten.mul.int %int4_1484, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1485 = torch.constant.int 4096 - %1457 = torch.prim.ListConstruct %1456, %int4096_1485 : (!torch.int, !torch.int) -> !torch.list - %1458 = torch.aten.view %1454, %1457 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1458, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1459 = torch.aten.mm %1458, %1455 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %1459, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_1486 = torch.constant.int 4 - %int14336_1487 = torch.constant.int 14336 - %1460 = torch.prim.ListConstruct %int4_1486, %306, %int14336_1487 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1461 = torch.aten.view %1459, %1460 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %1461, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %1462 = torch.aten.silu %1461 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %1462, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_1488 = torch.constant.int -2 - %int-1_1489 = torch.constant.int -1 - %1463 = torch.aten.transpose.int %53, %int-2_1488, %int-1_1489 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_1490 = torch.constant.int 4 - %1464 = torch.aten.mul.int %int4_1490, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1491 = torch.constant.int 4096 - %1465 = torch.prim.ListConstruct %1464, %int4096_1491 : (!torch.int, !torch.int) -> !torch.list - %1466 = torch.aten.view %1454, %1465 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1466, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1467 = torch.aten.mm %1466, %1463 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %1467, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_1492 = torch.constant.int 4 - %int14336_1493 = torch.constant.int 14336 - %1468 = torch.prim.ListConstruct %int4_1492, %306, %int14336_1493 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1469 = torch.aten.view %1467, %1468 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %1469, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %1470 = torch.aten.mul.Tensor %1462, %1469 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %1470, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_1494 = torch.constant.int -2 - %int-1_1495 = torch.constant.int -1 - %1471 = torch.aten.transpose.int %54, %int-2_1494, %int-1_1495 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_1458 = torch.constant.int 4 + %none_1459 = torch.constant.none + %cpu_1460 = torch.constant.device "cpu" + 
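// annotation: the `+` ops around here rebuild the rotary-embedding tables inline, where
// the `-` side called the external f32 kernels (@sharktank_rotary_embedding_4_D_32_128_f32
// and @sharktank_rotary_embedding_4_D_8_128_f32): arange(0, 128, 2) gives 64 frequency
// slots, pow(5.0e+05, k/128) followed by reciprocal is the standard theta = 500000
// inverse frequency, and the gt/lt/where ladder below looks like Llama-3-style RoPE
// frequency scaling (scale factor 8, smoothly interpolated between wavelengths 2048 and
// 8192), after which the outer product with positions arange(131072) yields 131072x128
// cos/sin tables in f16.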
%false_1461 = torch.constant.bool false + %1444 = torch.aten.arange.start_step %int0_1455, %int128_1456, %int2_1457, %int4_1458, %none_1459, %cpu_1460, %false_1461 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_1462 = torch.constant.int 6 + %1445 = torch.prims.convert_element_type %1444, %int6_1462 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_1463 = torch.constant.int 128 + %1446 = torch.aten.div.Scalar %1445, %int128_1463 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_1464 = torch.constant.float 5.000000e+05 + %1447 = torch.aten.pow.Scalar %float5.000000e05_1464, %1446 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1448 = torch.aten.reciprocal %1447 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_1465 = torch.constant.float 1.000000e+00 + %1449 = torch.aten.mul.Scalar %1448, %float1.000000e00_1465 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %1450 = torch.aten.reciprocal %1449 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_1466 = torch.constant.float 6.2831853071795862 + %1451 = torch.aten.mul.Scalar %1450, %float6.283190e00_1466 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_1467 = torch.constant.float 8.192000e+03 + %1452 = torch.aten.gt.Scalar %1451, %float8.192000e03_1467 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_1468 = torch.constant.int 8 + %1453 = torch.aten.div.Scalar %1449, %int8_1468 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1454 = torch.aten.where.self %1452, %1453, %1449 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1455 = torch.aten.reciprocal %1451 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_1469 = torch.constant.int 8192 + %1456 = torch.aten.mul.Scalar %1455, %int8192_1469 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_1470 = torch.constant.int 1 + %int1_1471 = torch.constant.int 1 + %1457 = torch.aten.sub.Scalar %1456, %int1_1470, %int1_1471 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_1472 = torch.constant.int 3 + %1458 = torch.aten.div.Scalar %1457, %int3_1472 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_1473 = torch.constant.int 1 + %int1_1474 = torch.constant.int 1 + %1459 = torch.aten.rsub.Scalar %1458, %int1_1473, %int1_1474 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %1460 = torch.aten.mul.Tensor %1459, %1454 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_1475 = torch.constant.int 8 + %1461 = torch.aten.div.Scalar %1460, %int8_1475 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1462 = torch.aten.mul.Tensor %1458, %1454 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_1476 = torch.constant.int 1 + %1463 = torch.aten.add.Tensor %1461, %1462, %int1_1476 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_1477 = torch.constant.float 2.048000e+03 + %1464 = torch.aten.lt.Scalar %1451, %float2.048000e03_1477 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %1465 = torch.aten.bitwise_not %1464 : 
!torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_1478 = torch.constant.float 8.192000e+03 + %1466 = torch.aten.gt.Scalar %1451, %float8.192000e03_1478 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %1467 = torch.aten.bitwise_not %1466 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %1468 = torch.aten.mul.Tensor %1465, %1467 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %1469 = torch.aten.where.self %1468, %1463, %1454 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1470 = torch.prim.ListConstruct %1469, %1469 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_1479 = torch.constant.int -1 + %1471 = torch.aten.cat %1470, %int-1_1479 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_1480 = torch.constant.int 6 + %1472 = torch.prims.convert_element_type %1471, %int6_1480 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_1481 = torch.constant.int 1 + %1473 = torch.aten.unsqueeze %1443, %int1_1481 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_1482 = torch.constant.int 6 + %1474 = torch.prims.convert_element_type %1473, %int6_1482 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_1483 = torch.constant.int 0 + %1475 = torch.aten.unsqueeze %1472, %int0_1483 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_1484 = torch.constant.int 6 + %1476 = torch.prims.convert_element_type %1475, %int6_1484 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %1477 = torch.aten.mul.Tensor %1474, %1476 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %1478 = torch.aten.cos %1477 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_1485 = torch.constant.int 5 + %1479 = torch.prims.convert_element_type %1478, %int5_1485 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %1480 = torch.aten.sin %1477 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_1486 = torch.constant.int 5 + %1481 = torch.prims.convert_element_type %1480, %int5_1486 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_1487 = torch.constant.int 0 + %int0_1488 = torch.constant.int 0 + %int1_1489 = torch.constant.int 1 + %1482 = torch.aten.slice.Tensor %1479, %int0_1487, %int0_1488, %298, %int1_1489 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1482, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_1490 = torch.constant.int 1 + %int0_1491 = torch.constant.int 0 + %int9223372036854775807_1492 = torch.constant.int 9223372036854775807 + %int1_1493 = torch.constant.int 1 + %1483 = torch.aten.slice.Tensor %1482, %int1_1490, %int0_1491, %int9223372036854775807_1492, %int1_1493 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1483, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_1494 = torch.constant.int 0 + %int0_1495 = torch.constant.int 0 %int1_1496 = torch.constant.int 1 - %1472 = torch.aten.size.int %1461, %int1_1496 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_1497 
= torch.constant.int 4 - %1473 = torch.aten.mul.int %int4_1497, %1472 : !torch.int, !torch.int -> !torch.int - %int14336_1498 = torch.constant.int 14336 - %1474 = torch.prim.ListConstruct %1473, %int14336_1498 : (!torch.int, !torch.int) -> !torch.list - %1475 = torch.aten.view %1470, %1474 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %1475, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %1476 = torch.aten.mm %1475, %1471 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1476, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_1499 = torch.constant.int 4 - %int4096_1500 = torch.constant.int 4096 - %1477 = torch.prim.ListConstruct %int4_1499, %1472, %int4096_1500 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1478 = torch.aten.view %1476, %1477 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1478, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_1501 = torch.constant.int 1 - %1479 = torch.aten.add.Tensor %1444, %1478, %int1_1501 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1479, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_1502 = torch.constant.int 6 - %1480 = torch.prims.convert_element_type %1479, %int6_1502 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1480, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_1503 = torch.constant.int 2 - %1481 = torch.aten.pow.Tensor_Scalar %1480, %int2_1503 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1481, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_1504 = torch.constant.int -1 - %1482 = torch.prim.ListConstruct %int-1_1504 : (!torch.int) -> !torch.list - %true_1505 = torch.constant.bool true - %none_1506 = torch.constant.none - %1483 = torch.aten.mean.dim %1481, %1482, %true_1505, %none_1506 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1483, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_1507 = torch.constant.float 9.9999997473787516E-6 - %int1_1508 = torch.constant.int 1 - %1484 = torch.aten.add.Scalar %1483, %float9.999990e-06_1507, %int1_1508 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1484, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %1485 = torch.aten.rsqrt %1484 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1485, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %1486 = torch.aten.mul.Tensor %1480, %1485 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1486, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_1509 = torch.constant.int 5 - %1487 = torch.prims.convert_element_type %1486, %int5_1509 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> 
!torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1487, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %1488 = torch.aten.mul.Tensor %55, %1487 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1488, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_1510 = torch.constant.int 5 - %1489 = torch.prims.convert_element_type %1488, %int5_1510 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1489, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_1511 = torch.constant.int -2 - %int-1_1512 = torch.constant.int -1 - %1490 = torch.aten.transpose.int %56, %int-2_1511, %int-1_1512 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_1513 = torch.constant.int 4 - %1491 = torch.aten.mul.int %int4_1513, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1514 = torch.constant.int 4096 - %1492 = torch.prim.ListConstruct %1491, %int4096_1514 : (!torch.int, !torch.int) -> !torch.list - %1493 = torch.aten.view %1489, %1492 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1493, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1494 = torch.aten.mm %1493, %1490 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1494, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_1515 = torch.constant.int 4 - %int4096_1516 = torch.constant.int 4096 - %1495 = torch.prim.ListConstruct %int4_1515, %306, %int4096_1516 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1496 = torch.aten.view %1494, %1495 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1496, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_1517 = torch.constant.int -2 - %int-1_1518 = torch.constant.int -1 - %1497 = torch.aten.transpose.int %57, %int-2_1517, %int-1_1518 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_1519 = torch.constant.int 4 - %1498 = torch.aten.mul.int %int4_1519, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1520 = torch.constant.int 4096 - %1499 = torch.prim.ListConstruct %1498, %int4096_1520 : (!torch.int, !torch.int) -> !torch.list - %1500 = torch.aten.view %1489, %1499 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1500, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1501 = torch.aten.mm %1500, %1497 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %1501, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_1521 = torch.constant.int 4 - %int1024_1522 = torch.constant.int 1024 - %1502 = torch.prim.ListConstruct %int4_1521, %306, %int1024_1522 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1503 = torch.aten.view %1501, %1502 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %1503, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_1523 = torch.constant.int -2 - 
%int-1_1524 = torch.constant.int -1 - %1504 = torch.aten.transpose.int %58, %int-2_1523, %int-1_1524 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %1484 = torch.aten.slice.Tensor %1481, %int0_1494, %int0_1495, %298, %int1_1496 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1484, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_1497 = torch.constant.int 1 + %int0_1498 = torch.constant.int 0 + %int9223372036854775807_1499 = torch.constant.int 9223372036854775807 + %int1_1500 = torch.constant.int 1 + %1485 = torch.aten.slice.Tensor %1484, %int1_1497, %int0_1498, %int9223372036854775807_1499, %int1_1500 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1485, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_1501 = torch.constant.int 0 + %1486 = torch.aten.unsqueeze %1483, %int0_1501 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1486, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_1502 = torch.constant.int 1 + %int0_1503 = torch.constant.int 0 + %int9223372036854775807_1504 = torch.constant.int 9223372036854775807 + %int1_1505 = torch.constant.int 1 + %1487 = torch.aten.slice.Tensor %1486, %int1_1502, %int0_1503, %int9223372036854775807_1504, %int1_1505 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1487, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_1506 = torch.constant.int 2 + %1488 = torch.aten.unsqueeze %1487, %int2_1506 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1488, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_1507 = torch.constant.int 3 + %int0_1508 = torch.constant.int 0 + %int9223372036854775807_1509 = torch.constant.int 9223372036854775807 + %int1_1510 = torch.constant.int 1 + %1489 = torch.aten.slice.Tensor %1488, %int3_1507, %int0_1508, %int9223372036854775807_1509, %int1_1510 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1489, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_1511 = torch.constant.int 4 + %int1_1512 = torch.constant.int 1 + %int1_1513 = torch.constant.int 1 + %int1_1514 = torch.constant.int 1 + %1490 = torch.prim.ListConstruct %int4_1511, %int1_1512, %int1_1513, %int1_1514 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1491 = torch.aten.repeat %1489, %1490 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %1491, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_1515 = torch.constant.int 0 + %1492 = torch.aten.unsqueeze %1485, %int0_1515 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1492, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_1516 = torch.constant.int 1 + %int0_1517 = torch.constant.int 0 + %int9223372036854775807_1518 = torch.constant.int 9223372036854775807 
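The run of `+` ops from %1444 through the cos/sin tables (%1479, %1481) is easier to follow as the PyTorch it was traced from: the constants (theta 5.0e5, wavelength thresholds 8192 and 2048, divisor 8, smoothing denominator 3) are the Llama 3.1 rope-scaling recipe applied to 64 base frequencies from arange(0, 128, 2), replacing the old `-` path that built 128 interleaved frequencies via floor_divide. Below is a minimal reconstruction of what the new IR computes, assuming the usual Llama 3.1 parameter names (scale factor 8, low/high frequency factors 1 and 4, original context 8192); the helper name and keyword defaults are mine, not from this repo:

    import math
    import torch

    def scaled_rope_tables(head_dim=128, max_pos=131072, theta=5.0e5, factor=8.0,
                           old_ctx=8192.0, low_ff=1.0, high_ff=4.0):
        # arange(0, 128, 2) -> 64 base frequencies (half-split layout), cf. %1444-%1448
        idx = torch.arange(0, head_dim, 2, dtype=torch.int64).float()
        inv_freq = 1.0 / (theta ** (idx / head_dim))
        wavelen = 2.0 * math.pi / inv_freq                       # cf. %1450-%1451
        # long wavelengths (> 8192): scale the frequency down, cf. %1452-%1454
        inv_freq = torch.where(wavelen > old_ctx / low_ff, inv_freq / factor, inv_freq)
        # medium band (2048..8192): blend scaled and unscaled, cf. %1455-%1469
        smooth = (old_ctx / wavelen - low_ff) / (high_ff - low_ff)
        blended = (1.0 - smooth) * inv_freq / factor + smooth * inv_freq
        medium = (wavelen >= old_ctx / high_ff) & (wavelen <= old_ctx / low_ff)
        inv_freq = torch.where(medium, blended, inv_freq)
        # duplicate to the full head dim and build f16 cos/sin tables, cf. %1470-%1481
        freqs = torch.cat([inv_freq, inv_freq], dim=-1)          # [128]
        angles = torch.arange(max_pos).float()[:, None] * freqs[None, :]
        return angles.cos().half(), angles.sin().half()          # [131072, 128] each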
+ %int1_1519 = torch.constant.int 1 + %1493 = torch.aten.slice.Tensor %1492, %int1_1516, %int0_1517, %int9223372036854775807_1518, %int1_1519 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1493, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_1520 = torch.constant.int 2 + %1494 = torch.aten.unsqueeze %1493, %int2_1520 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1494, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_1521 = torch.constant.int 3 + %int0_1522 = torch.constant.int 0 + %int9223372036854775807_1523 = torch.constant.int 9223372036854775807 + %int1_1524 = torch.constant.int 1 + %1495 = torch.aten.slice.Tensor %1494, %int3_1521, %int0_1522, %int9223372036854775807_1523, %int1_1524 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1495, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> %int4_1525 = torch.constant.int 4 - %1505 = torch.aten.mul.int %int4_1525, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1526 = torch.constant.int 4096 - %1506 = torch.prim.ListConstruct %1505, %int4096_1526 : (!torch.int, !torch.int) -> !torch.list - %1507 = torch.aten.view %1489, %1506 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1507, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1508 = torch.aten.mm %1507, %1504 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %1508, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_1527 = torch.constant.int 4 - %int1024_1528 = torch.constant.int 1024 - %1509 = torch.prim.ListConstruct %int4_1527, %306, %int1024_1528 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1510 = torch.aten.view %1508, %1509 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %1510, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_1529 = torch.constant.int 4 - %int32_1530 = torch.constant.int 32 - %int128_1531 = torch.constant.int 128 - %1511 = torch.prim.ListConstruct %int4_1529, %306, %int32_1530, %int128_1531 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1512 = torch.aten.view %1496, %1511 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1512, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_1532 = torch.constant.int 4 - %int8_1533 = torch.constant.int 8 - %int128_1534 = torch.constant.int 128 - %1513 = torch.prim.ListConstruct %int4_1532, %306, %int8_1533, %int128_1534 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1514 = torch.aten.view %1503, %1513 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1514, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_1535 = torch.constant.int 4 - %int8_1536 = torch.constant.int 8 - %int128_1537 = torch.constant.int 128 - %1515 = torch.prim.ListConstruct %int4_1535, %306, %int8_1536, %int128_1537 : (!torch.int, 
!torch.int, !torch.int, !torch.int) -> !torch.list - %1516 = torch.aten.view %1510, %1515 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1516, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_1538 = torch.constant.int 131072 - %none_1539 = torch.constant.none + %int1_1526 = torch.constant.int 1 + %int1_1527 = torch.constant.int 1 + %int1_1528 = torch.constant.int 1 + %1496 = torch.prim.ListConstruct %int4_1525, %int1_1526, %int1_1527, %int1_1528 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1497 = torch.aten.repeat %1495, %1496 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %1497, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %1498 = torch.aten.mul.Tensor %1438, %1491 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1498, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_1529 = torch.constant.int 3 + %int0_1530 = torch.constant.int 0 + %int64_1531 = torch.constant.int 64 + %int1_1532 = torch.constant.int 1 + %1499 = torch.aten.slice.Tensor %1438, %int3_1529, %int0_1530, %int64_1531, %int1_1532 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %1499, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_1533 = torch.constant.int 3 + %int64_1534 = torch.constant.int 64 + %int9223372036854775807_1535 = torch.constant.int 9223372036854775807 + %int1_1536 = torch.constant.int 1 + %1500 = torch.aten.slice.Tensor %1438, %int3_1533, %int64_1534, %int9223372036854775807_1535, %int1_1536 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %1500, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %1501 = torch.aten.neg %1500 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %1501, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %1502 = torch.prim.ListConstruct %1501, %1499 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_1537 = torch.constant.int -1 + %1503 = torch.aten.cat %1502, %int-1_1537 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1503, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %1504 = torch.aten.mul.Tensor %1503, %1497 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1504, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_1538 = torch.constant.int 1 + %1505 = torch.aten.add.Tensor %1498, %1504, %int1_1538 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1505, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_1539 = torch.constant.int 131072 %none_1540 = torch.constant.none - %cpu_1541 = torch.constant.device "cpu" - %false_1542 = torch.constant.bool false - %1517 = torch.aten.arange 
%int131072_1538, %none_1539, %none_1540, %cpu_1541, %false_1542 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_1543 = torch.constant.int 0 - %int128_1544 = torch.constant.int 128 - %none_1545 = torch.constant.none - %none_1546 = torch.constant.none - %cpu_1547 = torch.constant.device "cpu" - %false_1548 = torch.constant.bool false - %1518 = torch.aten.arange.start %int0_1543, %int128_1544, %none_1545, %none_1546, %cpu_1547, %false_1548 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_1549 = torch.constant.int 2 - %1519 = torch.aten.floor_divide.Scalar %1518, %int2_1549 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_1550 = torch.constant.int 6 - %1520 = torch.prims.convert_element_type %1519, %int6_1550 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_1551 = torch.constant.int 128 - %1521 = torch.aten.div.Scalar %1520, %int128_1551 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_1552 = torch.constant.float 2.000000e+00 - %1522 = torch.aten.mul.Scalar %1521, %float2.000000e00_1552 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %none_1541 = torch.constant.none + %cpu_1542 = torch.constant.device "cpu" + %false_1543 = torch.constant.bool false + %1506 = torch.aten.arange %int131072_1539, %none_1540, %none_1541, %cpu_1542, %false_1543 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_1544 = torch.constant.int 0 + %int128_1545 = torch.constant.int 128 + %int2_1546 = torch.constant.int 2 + %int4_1547 = torch.constant.int 4 + %none_1548 = torch.constant.none + %cpu_1549 = torch.constant.device "cpu" + %false_1550 = torch.constant.bool false + %1507 = torch.aten.arange.start_step %int0_1544, %int128_1545, %int2_1546, %int4_1547, %none_1548, %cpu_1549, %false_1550 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_1551 = torch.constant.int 6 + %1508 = torch.prims.convert_element_type %1507, %int6_1551 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_1552 = torch.constant.int 128 + %1509 = torch.aten.div.Scalar %1508, %int128_1552 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %float5.000000e05_1553 = torch.constant.float 5.000000e+05 - %1523 = torch.aten.pow.Scalar %float5.000000e05_1553, %1522 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %1524 = torch.aten.reciprocal %1523 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> + %1510 = torch.aten.pow.Scalar %float5.000000e05_1553, %1509 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1511 = torch.aten.reciprocal %1510 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> %float1.000000e00_1554 = torch.constant.float 1.000000e+00 - %1525 = torch.aten.mul.Scalar %1524, %float1.000000e00_1554 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_1555 = torch.constant.int 1 - %1526 = torch.aten.unsqueeze %1517, %int1_1555 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_1556 = torch.constant.int 0 - %1527 = torch.aten.unsqueeze %1525, %int0_1556 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %1528 = torch.aten.mul.Tensor %1526, %1527 : !torch.vtensor<[131072,1],si64>, 
!torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_1557 = torch.constant.int 1 - %1529 = torch.aten.size.int %1496, %int1_1557 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_1558 = torch.constant.int 0 - %1530 = torch.aten.add.int %int0_1558, %1529 : !torch.int, !torch.int -> !torch.int - %int0_1559 = torch.constant.int 0 - %int0_1560 = torch.constant.int 0 - %int1_1561 = torch.constant.int 1 - %1531 = torch.aten.slice.Tensor %1528, %int0_1559, %int0_1560, %1530, %int1_1561 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1531, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %1512 = torch.aten.mul.Scalar %1511, %float1.000000e00_1554 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %1513 = torch.aten.reciprocal %1512 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_1555 = torch.constant.float 6.2831853071795862 + %1514 = torch.aten.mul.Scalar %1513, %float6.283190e00_1555 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_1556 = torch.constant.float 8.192000e+03 + %1515 = torch.aten.gt.Scalar %1514, %float8.192000e03_1556 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_1557 = torch.constant.int 8 + %1516 = torch.aten.div.Scalar %1512, %int8_1557 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1517 = torch.aten.where.self %1515, %1516, %1512 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1518 = torch.aten.reciprocal %1514 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_1558 = torch.constant.int 8192 + %1519 = torch.aten.mul.Scalar %1518, %int8192_1558 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_1559 = torch.constant.int 1 + %int1_1560 = torch.constant.int 1 + %1520 = torch.aten.sub.Scalar %1519, %int1_1559, %int1_1560 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_1561 = torch.constant.int 3 + %1521 = torch.aten.div.Scalar %1520, %int3_1561 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_1562 = torch.constant.int 1 - %int0_1563 = torch.constant.int 0 - %int9223372036854775807_1564 = torch.constant.int 9223372036854775807 + %int1_1563 = torch.constant.int 1 + %1522 = torch.aten.rsub.Scalar %1521, %int1_1562, %int1_1563 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %1523 = torch.aten.mul.Tensor %1522, %1517 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_1564 = torch.constant.int 8 + %1524 = torch.aten.div.Scalar %1523, %int8_1564 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1525 = torch.aten.mul.Tensor %1521, %1517 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> %int1_1565 = torch.constant.int 1 - %1532 = torch.aten.slice.Tensor %1531, %int1_1562, %int0_1563, %int9223372036854775807_1564, %int1_1565 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1532, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1566 = torch.constant.int 1 - %int0_1567 = torch.constant.int 0 - %int9223372036854775807_1568 = torch.constant.int 9223372036854775807 - %int1_1569 = torch.constant.int 1 - 
%1533 = torch.aten.slice.Tensor %1532, %int1_1566, %int0_1567, %int9223372036854775807_1568, %int1_1569 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1533, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_1570 = torch.constant.int 0 - %1534 = torch.aten.unsqueeze %1533, %int0_1570 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1534, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_1571 = torch.constant.int 1 + %1526 = torch.aten.add.Tensor %1524, %1525, %int1_1565 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_1566 = torch.constant.float 2.048000e+03 + %1527 = torch.aten.lt.Scalar %1514, %float2.048000e03_1566 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %1528 = torch.aten.bitwise_not %1527 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_1567 = torch.constant.float 8.192000e+03 + %1529 = torch.aten.gt.Scalar %1514, %float8.192000e03_1567 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %1530 = torch.aten.bitwise_not %1529 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %1531 = torch.aten.mul.Tensor %1528, %1530 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %1532 = torch.aten.where.self %1531, %1526, %1517 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1533 = torch.prim.ListConstruct %1532, %1532 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_1568 = torch.constant.int -1 + %1534 = torch.aten.cat %1533, %int-1_1568 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_1569 = torch.constant.int 6 + %1535 = torch.prims.convert_element_type %1534, %int6_1569 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_1570 = torch.constant.int 1 + %1536 = torch.aten.unsqueeze %1506, %int1_1570 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_1571 = torch.constant.int 6 + %1537 = torch.prims.convert_element_type %1536, %int6_1571 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> %int0_1572 = torch.constant.int 0 - %int9223372036854775807_1573 = torch.constant.int 9223372036854775807 - %int1_1574 = torch.constant.int 1 - %1535 = torch.aten.slice.Tensor %1534, %int1_1571, %int0_1572, %int9223372036854775807_1573, %int1_1574 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1535, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_1575 = torch.constant.int 2 + %1538 = torch.aten.unsqueeze %1535, %int0_1572 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_1573 = torch.constant.int 6 + %1539 = torch.prims.convert_element_type %1538, %int6_1573 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %1540 = torch.aten.mul.Tensor %1537, %1539 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %1541 = torch.aten.cos %1540 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_1574 = torch.constant.int 5 + %1542 = torch.prims.convert_element_type %1541, %int5_1574 : 
!torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %1543 = torch.aten.sin %1540 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_1575 = torch.constant.int 5 + %1544 = torch.prims.convert_element_type %1543, %int5_1575 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> %int0_1576 = torch.constant.int 0 - %int9223372036854775807_1577 = torch.constant.int 9223372036854775807 + %int0_1577 = torch.constant.int 0 %int1_1578 = torch.constant.int 1 - %1536 = torch.aten.slice.Tensor %1535, %int2_1575, %int0_1576, %int9223372036854775807_1577, %int1_1578 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1536, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_1579 = torch.constant.int 4 - %int1_1580 = torch.constant.int 1 - %int1_1581 = torch.constant.int 1 - %1537 = torch.prim.ListConstruct %int4_1579, %int1_1580, %int1_1581 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1538 = torch.aten.repeat %1536, %1537 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %1538, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_1582 = torch.constant.int 6 - %1539 = torch.prims.convert_element_type %1512, %int6_1582 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %1539, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %1540 = torch_c.to_builtin_tensor %1539 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %1541 = torch_c.to_builtin_tensor %1538 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %1542 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%1540, %1541) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %1543 = torch_c.from_builtin_tensor %1542 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %1543, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_1583 = torch.constant.int 5 - %1544 = torch.prims.convert_element_type %1543, %int5_1583 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1544, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_1584 = torch.constant.int 131072 - %none_1585 = torch.constant.none - %none_1586 = torch.constant.none - %cpu_1587 = torch.constant.device "cpu" - %false_1588 = torch.constant.bool false - %1545 = torch.aten.arange %int131072_1584, %none_1585, %none_1586, %cpu_1587, %false_1588 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_1589 = torch.constant.int 0 - %int128_1590 = torch.constant.int 128 - %none_1591 = torch.constant.none - %none_1592 = torch.constant.none - %cpu_1593 = torch.constant.device "cpu" - %false_1594 = torch.constant.bool false - %1546 = torch.aten.arange.start %int0_1589, %int128_1590, %none_1591, %none_1592, %cpu_1593, %false_1594 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> + %1545 = torch.aten.slice.Tensor %1542, %int0_1576, %int0_1577, %298, %int1_1578 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> 
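With the tables built, the `+` ops apply them with the standard rotate-half formulation, entirely in f16: for the 32-head queries this is %1498-%1505 above (x*cos + cat([-x2, x1])*sin, with cos/sin repeated to [4,?,1,128]), and the same pattern follows for the 8-head keys at %1561-%1568. The removed `-` path instead upcast to f32 and dispatched to the custom kernels @sharktank_rotary_embedding_4_D_32_128_f32 and @sharktank_rotary_embedding_4_D_8_128_f32. A sketch of the inlined computation; the function names are mine:

    import torch

    def rotate_half(x):
        # split the head dim in two and rotate: (x1, x2) -> (-x2, x1), cf. %1499-%1503
        half = x.shape[-1] // 2
        return torch.cat([-x[..., half:], x[..., :half]], dim=-1)

    def apply_rope(x, cos, sin):
        # x: [bs, seq, heads, head_dim]; cos/sin: [seq, head_dim] slices of the tables.
        # Broadcasting stands in for the explicit unsqueeze/repeat to [4, ?, 1, 128].
        cos = cos[None, :, None, :]
        sin = sin[None, :, None, :]
        return x * cos + rotate_half(x) * sin                    # cf. %1498, %1504, %1505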
+ torch.bind_symbolic_shape %1545, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_1579 = torch.constant.int 1 + %int0_1580 = torch.constant.int 0 + %int9223372036854775807_1581 = torch.constant.int 9223372036854775807 + %int1_1582 = torch.constant.int 1 + %1546 = torch.aten.slice.Tensor %1545, %int1_1579, %int0_1580, %int9223372036854775807_1581, %int1_1582 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1546, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_1583 = torch.constant.int 0 + %int0_1584 = torch.constant.int 0 + %int1_1585 = torch.constant.int 1 + %1547 = torch.aten.slice.Tensor %1544, %int0_1583, %int0_1584, %298, %int1_1585 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1547, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_1586 = torch.constant.int 1 + %int0_1587 = torch.constant.int 0 + %int9223372036854775807_1588 = torch.constant.int 9223372036854775807 + %int1_1589 = torch.constant.int 1 + %1548 = torch.aten.slice.Tensor %1547, %int1_1586, %int0_1587, %int9223372036854775807_1588, %int1_1589 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1548, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_1590 = torch.constant.int 0 + %1549 = torch.aten.unsqueeze %1546, %int0_1590 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1549, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_1591 = torch.constant.int 1 + %int0_1592 = torch.constant.int 0 + %int9223372036854775807_1593 = torch.constant.int 9223372036854775807 + %int1_1594 = torch.constant.int 1 + %1550 = torch.aten.slice.Tensor %1549, %int1_1591, %int0_1592, %int9223372036854775807_1593, %int1_1594 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1550, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int2_1595 = torch.constant.int 2 - %1547 = torch.aten.floor_divide.Scalar %1546, %int2_1595 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_1596 = torch.constant.int 6 - %1548 = torch.prims.convert_element_type %1547, %int6_1596 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_1597 = torch.constant.int 128 - %1549 = torch.aten.div.Scalar %1548, %int128_1597 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_1598 = torch.constant.float 2.000000e+00 - %1550 = torch.aten.mul.Scalar %1549, %float2.000000e00_1598 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_1599 = torch.constant.float 5.000000e+05 - %1551 = torch.aten.pow.Scalar %float5.000000e05_1599, %1550 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %1552 = torch.aten.reciprocal %1551 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_1600 = torch.constant.float 1.000000e+00 - %1553 = torch.aten.mul.Scalar %1552, %float1.000000e00_1600 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %1551 = torch.aten.unsqueeze %1550, %int2_1595 : 
!torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1551, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_1596 = torch.constant.int 3 + %int0_1597 = torch.constant.int 0 + %int9223372036854775807_1598 = torch.constant.int 9223372036854775807 + %int1_1599 = torch.constant.int 1 + %1552 = torch.aten.slice.Tensor %1551, %int3_1596, %int0_1597, %int9223372036854775807_1598, %int1_1599 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1552, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_1600 = torch.constant.int 4 %int1_1601 = torch.constant.int 1 - %1554 = torch.aten.unsqueeze %1545, %int1_1601 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_1602 = torch.constant.int 0 - %1555 = torch.aten.unsqueeze %1553, %int0_1602 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %1556 = torch.aten.mul.Tensor %1554, %1555 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %int1_1602 = torch.constant.int 1 %int1_1603 = torch.constant.int 1 - %1557 = torch.aten.size.int %1503, %int1_1603 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int + %1553 = torch.prim.ListConstruct %int4_1600, %int1_1601, %int1_1602, %int1_1603 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1554 = torch.aten.repeat %1552, %1553 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %1554, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> %int0_1604 = torch.constant.int 0 - %1558 = torch.aten.add.int %int0_1604, %1557 : !torch.int, !torch.int -> !torch.int - %int0_1605 = torch.constant.int 0 + %1555 = torch.aten.unsqueeze %1548, %int0_1604 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1555, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_1605 = torch.constant.int 1 %int0_1606 = torch.constant.int 0 - %int1_1607 = torch.constant.int 1 - %1559 = torch.aten.slice.Tensor %1556, %int0_1605, %int0_1606, %1558, %int1_1607 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1559, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %int9223372036854775807_1607 = torch.constant.int 9223372036854775807 %int1_1608 = torch.constant.int 1 - %int0_1609 = torch.constant.int 0 - %int9223372036854775807_1610 = torch.constant.int 9223372036854775807 - %int1_1611 = torch.constant.int 1 - %1560 = torch.aten.slice.Tensor %1559, %int1_1608, %int0_1609, %int9223372036854775807_1610, %int1_1611 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1560, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1612 = torch.constant.int 1 - %int0_1613 = torch.constant.int 0 - %int9223372036854775807_1614 = torch.constant.int 9223372036854775807 + %1556 = torch.aten.slice.Tensor %1555, %int1_1605, %int0_1606, %int9223372036854775807_1607, %int1_1608 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + 
torch.bind_symbolic_shape %1556, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_1609 = torch.constant.int 2 + %1557 = torch.aten.unsqueeze %1556, %int2_1609 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1557, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_1610 = torch.constant.int 3 + %int0_1611 = torch.constant.int 0 + %int9223372036854775807_1612 = torch.constant.int 9223372036854775807 + %int1_1613 = torch.constant.int 1 + %1558 = torch.aten.slice.Tensor %1557, %int3_1610, %int0_1611, %int9223372036854775807_1612, %int1_1613 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1558, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_1614 = torch.constant.int 4 %int1_1615 = torch.constant.int 1 - %1561 = torch.aten.slice.Tensor %1560, %int1_1612, %int0_1613, %int9223372036854775807_1614, %int1_1615 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1561, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_1616 = torch.constant.int 0 - %1562 = torch.aten.unsqueeze %1561, %int0_1616 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1562, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> + %int1_1616 = torch.constant.int 1 %int1_1617 = torch.constant.int 1 - %int0_1618 = torch.constant.int 0 - %int9223372036854775807_1619 = torch.constant.int 9223372036854775807 - %int1_1620 = torch.constant.int 1 - %1563 = torch.aten.slice.Tensor %1562, %int1_1617, %int0_1618, %int9223372036854775807_1619, %int1_1620 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1563, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_1621 = torch.constant.int 2 - %int0_1622 = torch.constant.int 0 - %int9223372036854775807_1623 = torch.constant.int 9223372036854775807 - %int1_1624 = torch.constant.int 1 - %1564 = torch.aten.slice.Tensor %1563, %int2_1621, %int0_1622, %int9223372036854775807_1623, %int1_1624 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1564, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_1625 = torch.constant.int 4 - %int1_1626 = torch.constant.int 1 + %1559 = torch.prim.ListConstruct %int4_1614, %int1_1615, %int1_1616, %int1_1617 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1560 = torch.aten.repeat %1558, %1559 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %1560, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %1561 = torch.aten.mul.Tensor %1440, %1554 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1561, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_1618 = torch.constant.int 3 + %int0_1619 = torch.constant.int 0 + %int64_1620 = torch.constant.int 64 + %int1_1621 = torch.constant.int 1 + %1562 = 
torch.aten.slice.Tensor %1440, %int3_1618, %int0_1619, %int64_1620, %int1_1621 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %1562, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_1622 = torch.constant.int 3 + %int64_1623 = torch.constant.int 64 + %int9223372036854775807_1624 = torch.constant.int 9223372036854775807 + %int1_1625 = torch.constant.int 1 + %1563 = torch.aten.slice.Tensor %1440, %int3_1622, %int64_1623, %int9223372036854775807_1624, %int1_1625 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %1563, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %1564 = torch.aten.neg %1563 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %1564, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %1565 = torch.prim.ListConstruct %1564, %1562 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_1626 = torch.constant.int -1 + %1566 = torch.aten.cat %1565, %int-1_1626 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1566, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %1567 = torch.aten.mul.Tensor %1566, %1560 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1567, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> %int1_1627 = torch.constant.int 1 - %1565 = torch.prim.ListConstruct %int4_1625, %int1_1626, %int1_1627 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1566 = torch.aten.repeat %1564, %1565 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %1566, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_1628 = torch.constant.int 6 - %1567 = torch.prims.convert_element_type %1514, %int6_1628 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %1567, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %1568 = torch_c.to_builtin_tensor %1567 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %1569 = torch_c.to_builtin_tensor %1566 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %1570 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%1568, %1569) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %1571 = torch_c.from_builtin_tensor %1570 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %1571, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_1629 = torch.constant.int 5 - %1572 = torch.prims.convert_element_type %1571, %int5_1629 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1572, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_1630 = torch.constant.int 64 - %1573 = torch.aten.mul.Scalar %arg2, %int64_1630 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1573, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int12 = 
torch.constant.int 12 - %int1_1631 = torch.constant.int 1 - %1574 = torch.aten.add.Scalar %1573, %int12, %int1_1631 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1574, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_1632 = torch.constant.int 4 - %int32_1633 = torch.constant.int 32 - %int8_1634 = torch.constant.int 8 - %int128_1635 = torch.constant.int 128 - %1575 = torch.prim.ListConstruct %int4_1632, %398, %int32_1633, %int8_1634, %int128_1635 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1576 = torch.aten.view %1572, %1575 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1576, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_1636 = torch.constant.int 4 - %1577 = torch.aten.mul.int %int4_1636, %398 : !torch.int, !torch.int -> !torch.int - %int32_1637 = torch.constant.int 32 - %int8_1638 = torch.constant.int 8 - %int128_1639 = torch.constant.int 128 - %1578 = torch.prim.ListConstruct %1577, %int32_1637, %int8_1638, %int128_1639 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1579 = torch.aten.view %1576, %1578 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1579, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_1640 = torch.constant.int 4 - %1580 = torch.aten.mul.int %int4_1640, %398 : !torch.int, !torch.int -> !torch.int - %1581 = torch.prim.ListConstruct %1580 : (!torch.int) -> !torch.list - %1582 = torch.aten.view %1574, %1581 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %1582, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_1641 = torch.constant.int 32 + %1568 = torch.aten.add.Tensor %1561, %1567, %int1_1627 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1568, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_1628 = torch.constant.int 32 + %1569 = torch.aten.mul.Scalar %arg2, %int32_1628 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1569, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int4_1629 = torch.constant.int 4 + %int1_1630 = torch.constant.int 1 + %1570 = torch.aten.add.Scalar %1569, %int4_1629, %int1_1630 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1570, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_1631 = torch.constant.int 2 + %1571 = torch.aten.mul.Scalar %1570, %int2_1631 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1571, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_1632 = torch.constant.int 0 + %int1_1633 = torch.constant.int 1 + %1572 = torch.aten.add.Scalar %1571, %int0_1632, %int1_1633 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1572, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %1573 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %1574 = torch.aten.view %1572, %1573 : !torch.vtensor<[4,?],si64>, !torch.list -> 
!torch.vtensor<[?],si64> + torch.bind_symbolic_shape %1574, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_1634 = torch.constant.int 4 + %int32_1635 = torch.constant.int 32 + %int8_1636 = torch.constant.int 8 + %int128_1637 = torch.constant.int 128 + %1575 = torch.prim.ListConstruct %int4_1634, %296, %int32_1635, %int8_1636, %int128_1637 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1576 = torch.aten.view %1568, %1575 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1576, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_1638 = torch.constant.int 32 + %int8_1639 = torch.constant.int 8 + %int128_1640 = torch.constant.int 128 + %1577 = torch.prim.ListConstruct %504, %int32_1638, %int8_1639, %int128_1640 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1578 = torch.aten.view %1576, %1577 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %1578, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_1641 = torch.constant.int 1 %int2_1642 = torch.constant.int 2 - %int32_1643 = torch.constant.int 32 - %int8_1644 = torch.constant.int 8 - %int128_1645 = torch.constant.int 128 - %1583 = torch.prim.ListConstruct %389, %int32_1641, %int2_1642, %int32_1643, %int8_1644, %int128_1645 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1584 = torch.aten.view %1416, %1583 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1584, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_1646 = torch.constant.int 32 - %1585 = torch.aten.mul.int %389, %int32_1646 : !torch.int, !torch.int -> !torch.int - %int2_1647 = torch.constant.int 2 - %1586 = torch.aten.mul.int %1585, %int2_1647 : !torch.int, !torch.int -> !torch.int - %int32_1648 = torch.constant.int 32 + %1579 = torch.aten.transpose.int %1578, %int1_1641, %int2_1642 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %1579, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_1643 = torch.constant.int 5 + %1580 = torch.prims.convert_element_type %1579, %int5_1643 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %1580, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_1644 = torch.constant.int 32 + %int2_1645 = torch.constant.int 2 + %int8_1646 = torch.constant.int 8 + %int32_1647 = torch.constant.int 32 + %int128_1648 = torch.constant.int 128 + %1581 = torch.prim.ListConstruct %297, %int32_1644, %int2_1645, %int8_1646, %int32_1647, %int128_1648 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1582 = torch.aten.view %1344, %1581 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1582, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> %int8_1649 = torch.constant.int 8 - %int128_1650 = torch.constant.int 128 - %1587 = torch.prim.ListConstruct %1586, %int32_1648, %int8_1649, %int128_1650 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list 
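The cache write that closes this block is where the layout change shows up: the `-` lines view the flat [?,2097152] cache as [?,32,2,32,8,128] (in-page sequence ahead of kv-heads) and scatter at slot %arg2*64 + 12, while the `+` lines view it as [?,32,2,8,32,128] (kv-heads first), transpose the incoming keys to [?,8,32,128] to match (%1579), and build the slot ids stepwise as (%arg2*32 + 4)*2 + 0. A hedged reading of that arithmetic, with one scatter slot per (page, transformer-block, k-or-v) triple; the parameter names and the partition 0 = key convention are my inference from the constants, not from the source:

    import torch

    def write_kv_page(cache, k, page_ids, block_idx, partition=0,
                      n_blocks=32, n_kv_heads=8, block_seq=32, head_dim=128):
        # cache: [n_pages, 2097152] flat f16 storage, viewed as
        # [n_pages, n_blocks, 2 (k/v), n_kv_heads, block_seq, head_dim], cf. %1582
        paged = cache.view(-1, n_blocks, 2, n_kv_heads, block_seq, head_dim)
        flat = paged.view(-1, n_kv_heads, block_seq, head_dim)   # [n_pages*64, 8, 32, 128]
        # one slot id per page entry, cf. %1569-%1574
        slots = (page_ids * n_blocks + block_idx) * 2 + partition
        # k arrives as [4, seq, 8, 128] with seq a multiple of the 32-token page size;
        # regroup into pages and move heads ahead of in-page position, cf. %1576-%1579
        k = k.reshape(-1, block_seq, n_kv_heads, head_dim).transpose(1, 2)
        flat.index_put_((slots.reshape(-1),), k)                 # cf. the index_put below
        return cache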
- %1588 = torch.aten.view %1584, %1587 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %1588, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %1589 = torch.prim.ListConstruct %1582 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
- %false_1651 = torch.constant.bool false
- %1590 = torch.aten.index_put %1588, %1589, %1579, %false_1651 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %1590, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int32_1652 = torch.constant.int 32
- %int2_1653 = torch.constant.int 2
- %int32_1654 = torch.constant.int 32
+ %int32_1650 = torch.constant.int 32
+ %int128_1651 = torch.constant.int 128
+ %1583 = torch.prim.ListConstruct %497, %int8_1649, %int32_1650, %int128_1651 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1584 = torch.aten.view %1582, %1583 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1584, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %1585 = torch.prim.ListConstruct %1574 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
+ %false_1652 = torch.constant.bool false
+ %1586 = torch.aten.index_put %1584, %1585, %1580, %false_1652 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1586, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int32_1653 = torch.constant.int 32
+ %int2_1654 = torch.constant.int 2
 %int8_1655 = torch.constant.int 8
- %int128_1656 = torch.constant.int 128
- %1591 = torch.prim.ListConstruct %389, %int32_1652, %int2_1653, %int32_1654, %int8_1655, %int128_1656 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1592 = torch.aten.view %1590, %1591 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %1592, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int2097152_1657 = torch.constant.int 2097152
- %1593 = torch.prim.ListConstruct %389, %int2097152_1657 : (!torch.int, !torch.int) -> !torch.list<int>
- %1594 = torch.aten.view %1592, %1593 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
- torch.bind_symbolic_shape %1594, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
- %int32_1658 = torch.constant.int 32
- %int2_1659 = torch.constant.int 2
- %int32_1660 = torch.constant.int 32
+ %int32_1656 = torch.constant.int 32
+ %int128_1657 = torch.constant.int 128
+ %1587 = torch.prim.ListConstruct %297, %int32_1653, %int2_1654, %int8_1655, %int32_1656, %int128_1657 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1588 = torch.aten.view %1586, %1587 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1588, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_1658 = torch.constant.int 2097152
+ %1589 = torch.prim.ListConstruct %297, %int2097152_1658 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1590 = torch.aten.view %1588, %1589 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %1590, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %int32_1659 = torch.constant.int 32
+ %int2_1660 = torch.constant.int 2
 %int8_1661 = torch.constant.int 8
- %int128_1662 = torch.constant.int 128
- %1595 = torch.prim.ListConstruct %389, %int32_1658, %int2_1659, %int32_1660, %int8_1661, %int128_1662 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1596 = torch.aten.view %1594, %1595 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %1596, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int32_1663 = torch.constant.int 32
+ %int32_1662 = torch.constant.int 32
+ %int128_1663 = torch.constant.int 128
+ %1591 = torch.prim.ListConstruct %297, %int32_1659, %int2_1660, %int8_1661, %int32_1662, %int128_1663 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1592 = torch.aten.view %1590, %1591 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1592, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
 %int8_1664 = torch.constant.int 8
- %int128_1665 = torch.constant.int 128
- %1597 = torch.prim.ListConstruct %1586, %int32_1663, %int8_1664, %int128_1665 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1598 = torch.aten.view %1596, %1597 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %1598, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int4_1666 = torch.constant.int 4
+ %int32_1665 = torch.constant.int 32
+ %int128_1666 = torch.constant.int 128
+ %1593 = torch.prim.ListConstruct %497, %int8_1664, %int32_1665, %int128_1666 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1594 = torch.aten.view %1592, %1593 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1594, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
 %int32_1667 = torch.constant.int 32
- %int8_1668 = torch.constant.int 8
- %int128_1669 = torch.constant.int 128
- %1599 = torch.prim.ListConstruct %int4_1666, %398, %int32_1667, %int8_1668, %int128_1669 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1600 = torch.aten.view %1516, %1599 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %1600, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int4_1670 = torch.constant.int 4
- %1601 = torch.aten.mul.int %int4_1670, %398 : !torch.int, !torch.int -> !torch.int
- %int32_1671 = torch.constant.int 32
- %int8_1672 = torch.constant.int 8
- %int128_1673 = torch.constant.int 128
- %1602 = torch.prim.ListConstruct %1601, %int32_1671, %int8_1672, %int128_1673 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1603 = torch.aten.view %1600, %1602 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %1603, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int1_1674 = torch.constant.int 1
- %int1_1675 = torch.constant.int 1
- %1604 = torch.aten.add.Scalar %1574, %int1_1674, %int1_1675 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %1604, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int4_1676 = torch.constant.int 4
- %1605 = torch.aten.mul.int %int4_1676, %398 : !torch.int, !torch.int -> !torch.int
- %1606 = torch.prim.ListConstruct %1605 : (!torch.int) -> !torch.list<int>
- %1607 = torch.aten.view %1604, %1606 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
- torch.bind_symbolic_shape %1607, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
- %1608 = torch.prim.ListConstruct %1607 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
- %false_1677 = torch.constant.bool false
- %1609 = torch.aten.index_put %1598, %1608, %1603, %false_1677 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %1609, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int32_1678 = torch.constant.int 32
- %int2_1679 = torch.constant.int 2
- %int32_1680 = torch.constant.int 32
- %int8_1681 = torch.constant.int 8
- %int128_1682 = torch.constant.int 128
- %1610 = torch.prim.ListConstruct %389, %int32_1678, %int2_1679, %int32_1680, %int8_1681, %int128_1682 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1611 = torch.aten.view %1609, %1610 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %1611, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int2097152_1683 = torch.constant.int 2097152
- %1612 = torch.prim.ListConstruct %389, %int2097152_1683 : (!torch.int, !torch.int) -> !torch.list<int>
- %1613 = torch.aten.view %1611, %1612 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
- torch.bind_symbolic_shape %1613, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
- %int-2_1684 = torch.constant.int -2
- %1614 = torch.aten.unsqueeze %1572, %int-2_1684 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
- torch.bind_symbolic_shape %1614, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
- %int4_1685 = torch.constant.int 4
+ %1595 = torch.aten.mul.Scalar %arg2, %int32_1667 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1595, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int4_1668 = torch.constant.int 4
+ %int1_1669 = torch.constant.int 1
+ %1596 = torch.aten.add.Scalar %1595, %int4_1668, %int1_1669 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1596, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int2_1670 = torch.constant.int 2
+ %1597 = torch.aten.mul.Scalar %1596, %int2_1670 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1597, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int1_1671 = torch.constant.int 1
+ %int1_1672 = torch.constant.int 1
+ %1598 = torch.aten.add.Scalar %1597, %int1_1671, %int1_1672 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1598, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %1599 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list<int>
+ %1600 = torch.aten.view %1598, %1599 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
+ torch.bind_symbolic_shape %1600, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
+ %int4_1673 = torch.constant.int 4
+ %int32_1674 = torch.constant.int 32
+ %int8_1675 = torch.constant.int 8
+ %int128_1676 = torch.constant.int 128
+ %1601 = torch.prim.ListConstruct %int4_1673, %296, %int32_1674, %int8_1675, %int128_1676 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1602 = torch.aten.view %1442, %1601 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %1602, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int32_1677 = torch.constant.int 32
+ %int8_1678 = torch.constant.int 8
+ %int128_1679 = torch.constant.int 128
+ %1603 = torch.prim.ListConstruct %504, %int32_1677, %int8_1678, %int128_1679 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1604 = torch.aten.view %1602, %1603 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
+ torch.bind_symbolic_shape %1604, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
+ %int1_1680 = torch.constant.int 1
+ %int2_1681 = torch.constant.int 2
+ %1605 = torch.aten.transpose.int %1604, %int1_1680, %int2_1681 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1605, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int5_1682 = torch.constant.int 5
+ %1606 = torch.prims.convert_element_type %1605, %int5_1682 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1606, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %1607 = torch.prim.ListConstruct %1600 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
+ %false_1683 = torch.constant.bool false
+ %1608 = torch.aten.index_put %1594, %1607, %1606, %false_1683 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1608, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int32_1684 = torch.constant.int 32
+ %int2_1685 = torch.constant.int 2
 %int8_1686 = torch.constant.int 8
- %int32_1687 = torch.constant.int 32
 %int128_1688 = torch.constant.int 128
+ %1609 = torch.prim.ListConstruct %297, %int32_1684, %int2_1685, %int8_1686, %int32_1687, %int128_1688 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1610 = torch.aten.view %1608, %1609 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1610, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_1689 = torch.constant.int 2097152
+ %1611 = torch.prim.ListConstruct %297, %int2097152_1689 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1612 = torch.aten.view %1610, %1611 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %1612, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %int-2_1690 = torch.constant.int -2
+ %1613 = torch.aten.unsqueeze %1568, %int-2_1690 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+ torch.bind_symbolic_shape %1613, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
 %int4_1691 = torch.constant.int 4
+ %int8_1692 = torch.constant.int 8
+ %int4_1693 = torch.constant.int 4
+ %int128_1694 = torch.constant.int 128
+ %1614 = torch.prim.ListConstruct %int4_1691, %298, %int8_1692, %int4_1693, %int128_1694 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %false_1695 = torch.constant.bool false
+ %1615 = torch.aten.expand %1613, %1614, %false_1695 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %1615, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int0_1696 = torch.constant.int 0
+ %1616 = torch.aten.clone %1615, %int0_1696 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %1616, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4_1697 = torch.constant.int 4
+ %int32_1698 = torch.constant.int 32
 %int128_1699 = torch.constant.int 128
- %1622 = torch.prim.ListConstruct %int4_1696, %1621, %int8_1697, %int4_1698, %int128_1699 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %false_1700 = torch.constant.bool false
- %1623 = torch.aten.expand %1620, %1622, %false_1700 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %1623, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int0_1701 = torch.constant.int 0
- %1624 = torch.aten.clone %1623, %int0_1701 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %1624, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int4_1702 = torch.constant.int 4
- %int32_1703 = torch.constant.int 32
 %int128_1704 = torch.constant.int 128
- %1625 = torch.prim.ListConstruct %int4_1702, %1621, %int32_1703, %int128_1704 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1626 = torch.aten._unsafe_view %1624, %1625 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %1626, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int1_1705 = torch.constant.int 1
- %int2_1706 = torch.constant.int 2
- %1627 = torch.aten.transpose.int %1544, %int1_1705, %int2_1706 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %1627, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_1707 = torch.constant.int 1
- %int2_1708 = torch.constant.int 2
- %1628 = torch.aten.transpose.int %1619, %int1_1707, %int2_1708 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %1628, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_1709 = torch.constant.int 1
- %int2_1710 = torch.constant.int 2
- %1629 = torch.aten.transpose.int %1626, %int1_1709, %int2_1710 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %1629, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %float0.000000e00_1711 = torch.constant.float 0.000000e+00
- %true_1712 = torch.constant.bool true
- %none_1713 = torch.constant.none
- %none_1714 = torch.constant.none
- %1630:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1627, %1628, %1629, %float0.000000e00_1711, %true_1712, %none_1713, %none_1714) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>)
- torch.bind_symbolic_shape %1630#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_1715 = torch.constant.int 1
- %int2_1716 = torch.constant.int 2
- %1631 = torch.aten.transpose.int %1630#0, %int1_1715, %int2_1716 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %1631, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int4_1717 = torch.constant.int 4
- %int4096_1718 = torch.constant.int 4096
- %1632 = torch.prim.ListConstruct %int4_1717, %1529, %int4096_1718 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1633 = torch.aten.view %1631, %1632 : !torch.vtensor<[4,?,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1633, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_1719 = torch.constant.int -2
- %int-1_1720 = torch.constant.int -1
- %1634 = torch.aten.transpose.int %59, %int-2_1719, %int-1_1720 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %1617 = torch.prim.ListConstruct %int4_1697, %298, %int32_1698, %int128_1699 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1618 = torch.aten._unsafe_view %1616, %1617 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %1618, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int-2_1700 = torch.constant.int -2
+ %1619 = torch.aten.unsqueeze %1442, %int-2_1700 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+ torch.bind_symbolic_shape %1619, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+ %int4_1701 = torch.constant.int 4
+ %int8_1702 = torch.constant.int 8
+ %int4_1703 = torch.constant.int 4
+ %1620 = torch.prim.ListConstruct %int4_1701, %298, %int8_1702, %int4_1703, %int128_1704 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %false_1705 = torch.constant.bool false
+ %1621 = torch.aten.expand %1619, %1620, %false_1705 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %1621, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int0_1706 = torch.constant.int 0
+ %1622 = torch.aten.clone %1621, %int0_1706 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %1622, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4_1707 = torch.constant.int 4
+ %int32_1708 = torch.constant.int 32
+ %int128_1709 = torch.constant.int 128
+ %1623 = torch.prim.ListConstruct %int4_1707, %298, %int32_1708, %int128_1709 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1624 = torch.aten._unsafe_view %1622, %1623 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %1624, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int1_1710 = torch.constant.int 1
+ %int2_1711 = torch.constant.int 2
+ %1625 = torch.aten.transpose.int %1505, %int1_1710, %int2_1711 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %1625, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_1712 = torch.constant.int 1
+ %int2_1713 = torch.constant.int 2
+ %1626 = torch.aten.transpose.int %1618, %int1_1712, %int2_1713 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %1626, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_1714 = torch.constant.int 1
+ %int2_1715 = torch.constant.int 2
+ %1627 = torch.aten.transpose.int %1624, %int1_1714, %int2_1715 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %1627, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %float0.000000e00_1716 = torch.constant.float 0.000000e+00
+ %false_1717 = torch.constant.bool false
+ %none_1718 = torch.constant.none
+ %1628:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1625, %1626, %1627, %float0.000000e00_1716, %false_1717, %327, %none_1718) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>)
+ torch.bind_symbolic_shape %1628#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_1719 = torch.constant.int 1
+ %int2_1720 = torch.constant.int 2
+ %1629 = torch.aten.transpose.int %1628#0, %int1_1719, %int2_1720 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %1629, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
 %int4_1721 = torch.constant.int 4
- %1635 = torch.aten.mul.int %int4_1721, %1529 : !torch.int, !torch.int -> !torch.int
 %int4096_1722 = torch.constant.int 4096
- %1636 = torch.prim.ListConstruct %1635, %int4096_1722 : (!torch.int, !torch.int) -> !torch.list<int>
- %1637 = torch.aten.view %1633, %1636 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1637, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %1638 = torch.aten.mm %1637, %1634 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1638, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_1723 = torch.constant.int 4
- %int4096_1724 = torch.constant.int 4096
- %1639 = torch.prim.ListConstruct %int4_1723, %1529, %int4096_1724 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1640 = torch.aten.view %1638, %1639 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1640, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int1_1725 = torch.constant.int 1
- %1641 = torch.aten.add.Tensor %1479, %1640, %int1_1725 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1641, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int6_1726 = torch.constant.int 6
- %1642 = torch.prims.convert_element_type %1641, %int6_1726 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1642, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int2_1727 = torch.constant.int 2
- %1643 = torch.aten.pow.Tensor_Scalar %1642, %int2_1727 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1643, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int-1_1728 = torch.constant.int -1
- %1644 = torch.prim.ListConstruct %int-1_1728 : (!torch.int) -> !torch.list<int>
- %true_1729 = torch.constant.bool true
- %none_1730 = torch.constant.none
- %1645 = torch.aten.mean.dim %1643, %1644, %true_1729, %none_1730 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1645, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %float9.999990e-06_1731 = torch.constant.float 9.9999997473787516E-6
- %int1_1732 = torch.constant.int 1
- %1646 = torch.aten.add.Scalar %1645, %float9.999990e-06_1731, %int1_1732 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1646, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %1647 = torch.aten.rsqrt %1646 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1647, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %1648 = torch.aten.mul.Tensor %1642, %1647 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1648, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_1733 = torch.constant.int 5
- %1649 = torch.prims.convert_element_type %1648, %int5_1733 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1649, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %1650 = torch.aten.mul.Tensor %60, %1649 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1650, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_1734 = torch.constant.int 5
- %1651 = torch.prims.convert_element_type %1650, %int5_1734 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1651, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_1735 = torch.constant.int -2
- %int-1_1736 = torch.constant.int -1
- %1652 = torch.aten.transpose.int %61, %int-2_1735, %int-1_1736 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_1737 = torch.constant.int 4
- %1653 = torch.aten.mul.int %int4_1737, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_1738 = torch.constant.int 4096
- %1654 = torch.prim.ListConstruct %1653, %int4096_1738 : (!torch.int, !torch.int) -> !torch.list<int>
- %1655 = torch.aten.view %1651, %1654 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1655, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %1656 = torch.aten.mm %1655, %1652 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %1656, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %int4_1739 = torch.constant.int 4
- %int14336_1740 = torch.constant.int 14336
- %1657 = torch.prim.ListConstruct %int4_1739, %306, %int14336_1740 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1658 = torch.aten.view %1656, %1657 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %1658, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %1659 = torch.aten.silu %1658 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %1659, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %int-2_1741 = torch.constant.int -2
- %int-1_1742 = torch.constant.int -1
- %1660 = torch.aten.transpose.int %62, %int-2_1741, %int-1_1742 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %1630 = torch.prim.ListConstruct %int4_1721, %298, %int4096_1722 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1631 = torch.aten.view %1629, %1630 : !torch.vtensor<[4,?,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %1631, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_1723 = torch.constant.int -2
+ %int-1_1724 = torch.constant.int -1
+ %1632 = torch.aten.transpose.int %42, %int-2_1723, %int-1_1724 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int5_1725 = torch.constant.int 5
+ %1633 = torch.prims.convert_element_type %1632, %int5_1725 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int4096_1726 = torch.constant.int 4096
+ %1634 = torch.prim.ListConstruct %342, %int4096_1726 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1635 = torch.aten.view %1631, %1634 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %1635, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %1636 = torch.aten.mm %1635, %1633 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %1636, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %int4_1727 = torch.constant.int 4
+ %int4096_1728 = torch.constant.int 4096
+ %1637 = torch.prim.ListConstruct %int4_1727, %298, %int4096_1728 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1638 = torch.aten.view %1636, %1637 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %1638, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int1_1729 = torch.constant.int 1
+ %1639 = torch.aten.add.Tensor %1405, %1638, %int1_1729 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %1639, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int6_1730 = torch.constant.int 6
+ %1640 = torch.prims.convert_element_type %1639, %int6_1730 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %1640, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int2_1731 = torch.constant.int 2
+ %1641 = torch.aten.pow.Tensor_Scalar %1640, %int2_1731 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %1641, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int-1_1732 = torch.constant.int -1
+ %1642 = torch.prim.ListConstruct %int-1_1732 : (!torch.int) -> !torch.list<int>
+ %true_1733 = torch.constant.bool true
+ %none_1734 = torch.constant.none
+ %1643 = torch.aten.mean.dim %1641, %1642, %true_1733, %none_1734 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %1643, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %float9.999990e-06_1735 = torch.constant.float 9.9999997473787516E-6
+ %int1_1736 = torch.constant.int 1
+ %1644 = torch.aten.add.Scalar %1643, %float9.999990e-06_1735, %int1_1736 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %1644, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %1645 = torch.aten.rsqrt %1644 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %1645, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %1646 = torch.aten.mul.Tensor %1640, %1645 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %1646, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_1737 = torch.constant.int 5
+ %1647 = torch.prims.convert_element_type %1646, %int5_1737 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %1647, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %1648 = torch.aten.mul.Tensor %43, %1647 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %1648, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_1738 = torch.constant.int 5
+ %1649 = torch.prims.convert_element_type %1648, %int5_1738 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %1649, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_1739 = torch.constant.int -2
+ %int-1_1740 = torch.constant.int -1
+ %1650 = torch.aten.transpose.int %44, %int-2_1739, %int-1_1740 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int5_1741 = torch.constant.int 5
+ %1651 = torch.prims.convert_element_type %1650, %int5_1741 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int4096_1742 = torch.constant.int 4096
+ %1652 = torch.prim.ListConstruct %342, %int4096_1742 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1653 = torch.aten.view %1649, %1652 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %1653, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %1654 = torch.aten.mm %1653, %1651 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
+ torch.bind_symbolic_shape %1654, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
 %int4_1743 = torch.constant.int 4
- %1661 = torch.aten.mul.int %int4_1743, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_1744 = torch.constant.int 4096
- %1662 = torch.prim.ListConstruct %1661, %int4096_1744 : (!torch.int, !torch.int) -> !torch.list<int>
- %1663 = torch.aten.view %1651, %1662 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1663, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %1664 = torch.aten.mm %1663, %1660 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %1664, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %int4_1745 = torch.constant.int 4
- %int14336_1746 = torch.constant.int 14336
- %1665 = torch.prim.ListConstruct %int4_1745, %306, %int14336_1746 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1666 = torch.aten.view %1664, %1665 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %1666, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %1667 = torch.aten.mul.Tensor %1659, %1666 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %1667, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %int-2_1747 = torch.constant.int -2
- %int-1_1748 = torch.constant.int -1
- %1668 = torch.aten.transpose.int %63, %int-2_1747, %int-1_1748 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
- %int1_1749 = torch.constant.int 1
- %1669 = torch.aten.size.int %1658, %int1_1749 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int
- %int4_1750 = torch.constant.int 4
- %1670 = torch.aten.mul.int %int4_1750, %1669 : !torch.int, !torch.int -> !torch.int
- %int14336_1751 = torch.constant.int 14336
- %1671 = torch.prim.ListConstruct %1670, %int14336_1751 : (!torch.int, !torch.int) -> !torch.list<int>
- %1672 = torch.aten.view %1667, %1671 : !torch.vtensor<[4,?,14336],f16>, !torch.list<int> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %1672, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %1673 = torch.aten.mm %1672, %1668 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1673, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_1752 = torch.constant.int 4
- %int4096_1753 = torch.constant.int 4096
- %1674 = torch.prim.ListConstruct %int4_1752, %1669, %int4096_1753 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1675 = torch.aten.view %1673, %1674 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1675, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int1_1754 = torch.constant.int 1
- %1676 = torch.aten.add.Tensor %1641, %1675, %int1_1754 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1676, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int6_1755 = torch.constant.int 6
- %1677 = torch.prims.convert_element_type %1676, %int6_1755 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1677, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int2_1756 = torch.constant.int 2
- %1678 = torch.aten.pow.Tensor_Scalar %1677, %int2_1756 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1678, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int-1_1757 = torch.constant.int -1
- %1679 = torch.prim.ListConstruct %int-1_1757 : (!torch.int) -> !torch.list<int>
- %true_1758 = torch.constant.bool true
- %none_1759 = torch.constant.none
- %1680 = torch.aten.mean.dim %1678, %1679, %true_1758, %none_1759 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1680, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %float9.999990e-06_1760 = torch.constant.float 9.9999997473787516E-6
- %int1_1761 = torch.constant.int 1
- %1681 = torch.aten.add.Scalar %1680, %float9.999990e-06_1760, %int1_1761 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1681, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %1682 = torch.aten.rsqrt %1681 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1682, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %1683 = torch.aten.mul.Tensor %1677, %1682 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1683, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_1762 = torch.constant.int 5
- %1684 = torch.prims.convert_element_type %1683, %int5_1762 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1684, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %1685 = torch.aten.mul.Tensor %64, %1684 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1685, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_1763 = torch.constant.int 5
- %1686 = torch.prims.convert_element_type %1685, %int5_1763 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1686, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_1764 = torch.constant.int -2
- %int-1_1765 = torch.constant.int -1
- %1687 = torch.aten.transpose.int %65, %int-2_1764, %int-1_1765 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_1766 = torch.constant.int 4
- %1688 = torch.aten.mul.int %int4_1766, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_1767 = torch.constant.int 4096
- %1689 = torch.prim.ListConstruct %1688, %int4096_1767 : (!torch.int, !torch.int) -> !torch.list<int>
- %1690 = torch.aten.view %1686, %1689 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1690, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %1691 = torch.aten.mm %1690, %1687 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1691, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_1768 = torch.constant.int 4
- %int4096_1769 = torch.constant.int 4096
- %1692 = torch.prim.ListConstruct %int4_1768, %306, %int4096_1769 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1693 = torch.aten.view %1691, %1692 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1693, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_1770 = torch.constant.int -2
- %int-1_1771 = torch.constant.int -1
- %1694 = torch.aten.transpose.int %66, %int-2_1770, %int-1_1771 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_1772 = torch.constant.int 4
- %1695 = torch.aten.mul.int %int4_1772, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_1773 = torch.constant.int 4096
- %1696 = torch.prim.ListConstruct %1695, %int4096_1773 : (!torch.int, !torch.int) -> !torch.list<int>
- %1697 = torch.aten.view %1686, %1696 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1697, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %1698 = torch.aten.mm %1697, %1694 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
- torch.bind_symbolic_shape %1698, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
- %int4_1774 = torch.constant.int 4
- %int1024_1775 = torch.constant.int 1024
- %1699 = torch.prim.ListConstruct %int4_1774, %306, %int1024_1775 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1700 = torch.aten.view %1698, %1699 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
- torch.bind_symbolic_shape %1700, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
- %int-2_1776 = torch.constant.int -2
- %int-1_1777 = torch.constant.int -1
- %1701 = torch.aten.transpose.int %67, %int-2_1776, %int-1_1777 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_1778 = torch.constant.int 4
- %1702 = torch.aten.mul.int %int4_1778, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_1779 = torch.constant.int 4096
- %1703 = torch.prim.ListConstruct %1702, %int4096_1779 : (!torch.int, !torch.int) -> !torch.list<int>
- %1704 = torch.aten.view %1686, %1703 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1704, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %1705 = torch.aten.mm %1704, %1701 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
- torch.bind_symbolic_shape %1705, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
- %int4_1780 = torch.constant.int 4
- %int1024_1781 = torch.constant.int 1024
- %1706 = torch.prim.ListConstruct %int4_1780, %306, %int1024_1781 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1707 = torch.aten.view %1705, %1706 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
- torch.bind_symbolic_shape %1707, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
- %int4_1782 = torch.constant.int 4
- %int32_1783 = torch.constant.int 32
- %int128_1784 = torch.constant.int 128
- %1708 = torch.prim.ListConstruct %int4_1782, %306, %int32_1783, %int128_1784 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1709 = torch.aten.view %1693, %1708 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %1709, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int14336_1744 = torch.constant.int 14336
+ %1655 = torch.prim.ListConstruct %int4_1743, %298, %int14336_1744 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1656 = torch.aten.view %1654, %1655 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
+ torch.bind_symbolic_shape %1656, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
+ %1657 = torch.aten.silu %1656 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
+ torch.bind_symbolic_shape %1657, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
+ %int-2_1745 = torch.constant.int -2
+ %int-1_1746 = torch.constant.int -1
+ %1658 = torch.aten.transpose.int %45, %int-2_1745, %int-1_1746 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int5_1747 = torch.constant.int 5
+ %1659 = torch.prims.convert_element_type %1658, %int5_1747 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int4096_1748 = torch.constant.int 4096
+ %1660 = torch.prim.ListConstruct %342, %int4096_1748 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1661 = torch.aten.view %1649, %1660 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %1661, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %1662 = torch.aten.mm %1661, %1659 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
+ torch.bind_symbolic_shape %1662, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
+ %int4_1749 = torch.constant.int 4
+ %int14336_1750 = torch.constant.int 14336
+ %1663 = torch.prim.ListConstruct %int4_1749, %298, %int14336_1750 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1664 = torch.aten.view %1662, %1663 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
+ torch.bind_symbolic_shape %1664, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
+ %1665 = torch.aten.mul.Tensor %1657, %1664 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
+ torch.bind_symbolic_shape %1665, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
+ %int-2_1751 = torch.constant.int -2
+ %int-1_1752 = torch.constant.int -1
+ %1666 = torch.aten.transpose.int %46, %int-2_1751, %int-1_1752 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
+ %int5_1753 = torch.constant.int 5
+ %1667 = torch.prims.convert_element_type %1666, %int5_1753 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16>
+ %int14336_1754 = torch.constant.int 14336
+ %1668 = torch.prim.ListConstruct %342, %int14336_1754 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1669 = torch.aten.view %1665, %1668 : !torch.vtensor<[4,?,14336],f16>, !torch.list<int> -> !torch.vtensor<[?,14336],f16>
+ torch.bind_symbolic_shape %1669, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
+ %1670 = torch.aten.mm %1669, %1667 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %1670, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %int4_1755 = torch.constant.int 4
+ %int4096_1756 = torch.constant.int 4096
+ %1671 = torch.prim.ListConstruct %int4_1755, %298, %int4096_1756 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1672 = torch.aten.view %1670, %1671 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %1672, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int1_1757 = torch.constant.int 1
+ %1673 = torch.aten.add.Tensor %1639, %1672, %int1_1757 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %1673, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int6_1758 = torch.constant.int 6
+ %1674 = torch.prims.convert_element_type %1673, %int6_1758 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %1674, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int2_1759 = torch.constant.int 2
+ %1675 = torch.aten.pow.Tensor_Scalar %1674, %int2_1759 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %1675, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int-1_1760 = torch.constant.int -1
+ %1676 = torch.prim.ListConstruct %int-1_1760 : (!torch.int) -> !torch.list<int>
+ %true_1761 = torch.constant.bool true
+ %none_1762 = torch.constant.none
+ %1677 = torch.aten.mean.dim %1675, %1676, %true_1761, %none_1762 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %1677, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %float9.999990e-06_1763 = torch.constant.float 9.9999997473787516E-6
+ %int1_1764 = torch.constant.int 1
+ %1678 = torch.aten.add.Scalar %1677, %float9.999990e-06_1763, %int1_1764 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %1678, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %1679 = torch.aten.rsqrt %1678 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %1679, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %1680 = torch.aten.mul.Tensor %1674, %1679 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %1680, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_1765 = torch.constant.int 5
+ %1681 = torch.prims.convert_element_type %1680, %int5_1765 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %1681, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %1682 = torch.aten.mul.Tensor %47, %1681 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %1682, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_1766 = torch.constant.int 5
+ %1683 = torch.prims.convert_element_type %1682, %int5_1766 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %1683, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_1767 = torch.constant.int -2
+ %int-1_1768 = torch.constant.int -1
+ %1684 = torch.aten.transpose.int %48, %int-2_1767, %int-1_1768 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int5_1769 = torch.constant.int 5
+ %1685 = torch.prims.convert_element_type %1684, %int5_1769 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int4096_1770 = torch.constant.int 4096
+ %1686 = torch.prim.ListConstruct %342, %int4096_1770 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1687 = torch.aten.view %1683, %1686 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %1687, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %1688 = torch.aten.mm %1687, %1685 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %1688, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %int4_1771 = torch.constant.int 4
+ %int4096_1772 = torch.constant.int 4096
+ %1689 = torch.prim.ListConstruct %int4_1771, %298, %int4096_1772 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1690 = torch.aten.view %1688, %1689 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %1690, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_1773 = torch.constant.int -2
+ %int-1_1774 = torch.constant.int -1
+ %1691 = torch.aten.transpose.int %49, %int-2_1773, %int-1_1774 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int5_1775 = torch.constant.int 5
+ %1692 = torch.prims.convert_element_type %1691, %int5_1775 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int4096_1776 = torch.constant.int 4096
+ %1693 = torch.prim.ListConstruct %342, %int4096_1776 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1694 = torch.aten.view %1683, %1693 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %1694, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %1695 = torch.aten.mm %1694, %1692 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
+ torch.bind_symbolic_shape %1695, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
+ %int4_1777 = torch.constant.int 4
+ %int1024_1778 = torch.constant.int 1024
+ %1696 = torch.prim.ListConstruct %int4_1777, %298, %int1024_1778 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1697 = torch.aten.view %1695, %1696 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
+ torch.bind_symbolic_shape %1697, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
+ %int-2_1779 = torch.constant.int -2
+ %int-1_1780 = torch.constant.int -1
+ %1698 = torch.aten.transpose.int %50, %int-2_1779, %int-1_1780 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int5_1781 = torch.constant.int 5
+ %1699 = torch.prims.convert_element_type %1698, %int5_1781 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int4096_1782 = torch.constant.int 4096
+ %1700 = torch.prim.ListConstruct %342, %int4096_1782 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1701 = torch.aten.view %1683, %1700 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %1701, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %1702 = torch.aten.mm %1701, %1699 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
+ torch.bind_symbolic_shape %1702, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
+ %int4_1783 = torch.constant.int 4
+ %int1024_1784 = torch.constant.int 1024
+ %1703 = torch.prim.ListConstruct %int4_1783, %298, %int1024_1784 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1704 = torch.aten.view %1702, %1703 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
+ torch.bind_symbolic_shape %1704, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
 %int4_1785 = torch.constant.int 4
- %int8_1786 = torch.constant.int 8
+ %int32_1786 = torch.constant.int 32
 %int128_1787 = torch.constant.int 128
- %1710 = torch.prim.ListConstruct %int4_1785, %306, %int8_1786, %int128_1787 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1711 = torch.aten.view %1700, %1710 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %1711, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %1705 = torch.prim.ListConstruct %int4_1785, %298, %int32_1786, %int128_1787 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1706 = torch.aten.view %1690, %1705 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1706, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int4_1788 = torch.constant.int 4 %int8_1789 = torch.constant.int 8 %int128_1790 = torch.constant.int 128 - %1712 = torch.prim.ListConstruct %int4_1788, %306, %int8_1789, %int128_1790 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1713 = torch.aten.view %1707, %1712 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1713, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_1791 = torch.constant.int 131072 - %none_1792 = torch.constant.none - %none_1793 = torch.constant.none - %cpu_1794 = torch.constant.device "cpu" - %false_1795 = torch.constant.bool false - %1714 = torch.aten.arange %int131072_1791, %none_1792, %none_1793, %cpu_1794, %false_1795 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_1796 = torch.constant.int 0 - %int128_1797 = torch.constant.int 128 - %none_1798 = torch.constant.none - %none_1799 = torch.constant.none - %cpu_1800 = torch.constant.device "cpu" - %false_1801 = torch.constant.bool false - %1715 = torch.aten.arange.start %int0_1796, %int128_1797, %none_1798, %none_1799, %cpu_1800, %false_1801 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_1802 = torch.constant.int 2 - %1716 = torch.aten.floor_divide.Scalar %1715, %int2_1802 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_1803 = torch.constant.int 6 - %1717 = torch.prims.convert_element_type %1716, %int6_1803 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_1804 = torch.constant.int 128 - %1718 = torch.aten.div.Scalar %1717, %int128_1804 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_1805 = torch.constant.float 2.000000e+00 - %1719 = torch.aten.mul.Scalar %1718, %float2.000000e00_1805 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_1806 = torch.constant.float 5.000000e+05 - %1720 = torch.aten.pow.Scalar %float5.000000e05_1806, %1719 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %1721 = torch.aten.reciprocal %1720 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_1807 = torch.constant.float 1.000000e+00 - %1722 = torch.aten.mul.Scalar %1721, %float1.000000e00_1807 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_1808 = torch.constant.int 1 - %1723 = torch.aten.unsqueeze %1714, %int1_1808 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_1809 = torch.constant.int 0 - %1724 = torch.aten.unsqueeze %1722, %int0_1809 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %1725 = torch.aten.mul.Tensor %1723, %1724 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_1810 = torch.constant.int 1 - %1726 = torch.aten.size.int %1693, %int1_1810 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_1811 = torch.constant.int 0 - %1727 = torch.aten.add.int %int0_1811, %1726 : !torch.int, !torch.int -> !torch.int - %int0_1812 = torch.constant.int 0 - %int0_1813 = torch.constant.int 0 + %1707 = 
torch.prim.ListConstruct %int4_1788, %298, %int8_1789, %int128_1790 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1708 = torch.aten.view %1697, %1707 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1708, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_1791 = torch.constant.int 4 + %int8_1792 = torch.constant.int 8 + %int128_1793 = torch.constant.int 128 + %1709 = torch.prim.ListConstruct %int4_1791, %298, %int8_1792, %int128_1793 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1710 = torch.aten.view %1704, %1709 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1710, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_1794 = torch.constant.int 131072 + %none_1795 = torch.constant.none + %none_1796 = torch.constant.none + %cpu_1797 = torch.constant.device "cpu" + %false_1798 = torch.constant.bool false + %1711 = torch.aten.arange %int131072_1794, %none_1795, %none_1796, %cpu_1797, %false_1798 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_1799 = torch.constant.int 0 + %int128_1800 = torch.constant.int 128 + %int2_1801 = torch.constant.int 2 + %int4_1802 = torch.constant.int 4 + %none_1803 = torch.constant.none + %cpu_1804 = torch.constant.device "cpu" + %false_1805 = torch.constant.bool false + %1712 = torch.aten.arange.start_step %int0_1799, %int128_1800, %int2_1801, %int4_1802, %none_1803, %cpu_1804, %false_1805 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_1806 = torch.constant.int 6 + %1713 = torch.prims.convert_element_type %1712, %int6_1806 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_1807 = torch.constant.int 128 + %1714 = torch.aten.div.Scalar %1713, %int128_1807 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_1808 = torch.constant.float 5.000000e+05 + %1715 = torch.aten.pow.Scalar %float5.000000e05_1808, %1714 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1716 = torch.aten.reciprocal %1715 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_1809 = torch.constant.float 1.000000e+00 + %1717 = torch.aten.mul.Scalar %1716, %float1.000000e00_1809 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %1718 = torch.aten.reciprocal %1717 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_1810 = torch.constant.float 6.2831853071795862 + %1719 = torch.aten.mul.Scalar %1718, %float6.283190e00_1810 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_1811 = torch.constant.float 8.192000e+03 + %1720 = torch.aten.gt.Scalar %1719, %float8.192000e03_1811 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_1812 = torch.constant.int 8 + %1721 = torch.aten.div.Scalar %1717, %int8_1812 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1722 = torch.aten.where.self %1720, %1721, %1717 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1723 = torch.aten.reciprocal %1719 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_1813 = torch.constant.int 8192 + %1724 = torch.aten.mul.Scalar %1723, %int8192_1813 : 
!torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_1814 = torch.constant.int 1 - %1728 = torch.aten.slice.Tensor %1725, %int0_1812, %int0_1813, %1727, %int1_1814 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1728, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> %int1_1815 = torch.constant.int 1 - %int0_1816 = torch.constant.int 0 - %int9223372036854775807_1817 = torch.constant.int 9223372036854775807 + %1725 = torch.aten.sub.Scalar %1724, %int1_1814, %int1_1815 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_1816 = torch.constant.int 3 + %1726 = torch.aten.div.Scalar %1725, %int3_1816 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_1817 = torch.constant.int 1 %int1_1818 = torch.constant.int 1 - %1729 = torch.aten.slice.Tensor %1728, %int1_1815, %int0_1816, %int9223372036854775807_1817, %int1_1818 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1729, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1819 = torch.constant.int 1 - %int0_1820 = torch.constant.int 0 - %int9223372036854775807_1821 = torch.constant.int 9223372036854775807 - %int1_1822 = torch.constant.int 1 - %1730 = torch.aten.slice.Tensor %1729, %int1_1819, %int0_1820, %int9223372036854775807_1821, %int1_1822 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1730, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_1823 = torch.constant.int 0 - %1731 = torch.aten.unsqueeze %1730, %int0_1823 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1731, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_1824 = torch.constant.int 1 - %int0_1825 = torch.constant.int 0 - %int9223372036854775807_1826 = torch.constant.int 9223372036854775807 - %int1_1827 = torch.constant.int 1 - %1732 = torch.aten.slice.Tensor %1731, %int1_1824, %int0_1825, %int9223372036854775807_1826, %int1_1827 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1732, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_1828 = torch.constant.int 2 - %int0_1829 = torch.constant.int 0 - %int9223372036854775807_1830 = torch.constant.int 9223372036854775807 - %int1_1831 = torch.constant.int 1 - %1733 = torch.aten.slice.Tensor %1732, %int2_1828, %int0_1829, %int9223372036854775807_1830, %int1_1831 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1733, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_1832 = torch.constant.int 4 + %1727 = torch.aten.rsub.Scalar %1726, %int1_1817, %int1_1818 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %1728 = torch.aten.mul.Tensor %1727, %1722 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_1819 = torch.constant.int 8 + %1729 = torch.aten.div.Scalar %1728, %int8_1819 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1730 = torch.aten.mul.Tensor %1726, %1722 : 
!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_1820 = torch.constant.int 1 + %1731 = torch.aten.add.Tensor %1729, %1730, %int1_1820 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_1821 = torch.constant.float 2.048000e+03 + %1732 = torch.aten.lt.Scalar %1719, %float2.048000e03_1821 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %1733 = torch.aten.bitwise_not %1732 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_1822 = torch.constant.float 8.192000e+03 + %1734 = torch.aten.gt.Scalar %1719, %float8.192000e03_1822 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %1735 = torch.aten.bitwise_not %1734 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %1736 = torch.aten.mul.Tensor %1733, %1735 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %1737 = torch.aten.where.self %1736, %1731, %1722 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1738 = torch.prim.ListConstruct %1737, %1737 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_1823 = torch.constant.int -1 + %1739 = torch.aten.cat %1738, %int-1_1823 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_1824 = torch.constant.int 6 + %1740 = torch.prims.convert_element_type %1739, %int6_1824 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_1825 = torch.constant.int 1 + %1741 = torch.aten.unsqueeze %1711, %int1_1825 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_1826 = torch.constant.int 6 + %1742 = torch.prims.convert_element_type %1741, %int6_1826 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_1827 = torch.constant.int 0 + %1743 = torch.aten.unsqueeze %1740, %int0_1827 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_1828 = torch.constant.int 6 + %1744 = torch.prims.convert_element_type %1743, %int6_1828 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %1745 = torch.aten.mul.Tensor %1742, %1744 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %1746 = torch.aten.cos %1745 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_1829 = torch.constant.int 5 + %1747 = torch.prims.convert_element_type %1746, %int5_1829 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %1748 = torch.aten.sin %1745 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_1830 = torch.constant.int 5 + %1749 = torch.prims.convert_element_type %1748, %int5_1830 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_1831 = torch.constant.int 0 + %int0_1832 = torch.constant.int 0 %int1_1833 = torch.constant.int 1 + %1750 = torch.aten.slice.Tensor %1747, %int0_1831, %int0_1832, %298, %int1_1833 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1750, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_1834 = torch.constant.int 1 - %1734 = torch.prim.ListConstruct %int4_1832, %int1_1833, %int1_1834 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1735 = torch.aten.repeat %1733, %1734 : !torch.vtensor<[1,?,128],f32>, 
!torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %1735, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_1835 = torch.constant.int 6 - %1736 = torch.prims.convert_element_type %1709, %int6_1835 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %1736, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %1737 = torch_c.to_builtin_tensor %1736 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %1738 = torch_c.to_builtin_tensor %1735 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %1739 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%1737, %1738) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %1740 = torch_c.from_builtin_tensor %1739 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %1740, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_1836 = torch.constant.int 5 - %1741 = torch.prims.convert_element_type %1740, %int5_1836 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1741, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_1837 = torch.constant.int 131072 - %none_1838 = torch.constant.none - %none_1839 = torch.constant.none - %cpu_1840 = torch.constant.device "cpu" - %false_1841 = torch.constant.bool false - %1742 = torch.aten.arange %int131072_1837, %none_1838, %none_1839, %cpu_1840, %false_1841 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_1835 = torch.constant.int 0 + %int9223372036854775807_1836 = torch.constant.int 9223372036854775807 + %int1_1837 = torch.constant.int 1 + %1751 = torch.aten.slice.Tensor %1750, %int1_1834, %int0_1835, %int9223372036854775807_1836, %int1_1837 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1751, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_1838 = torch.constant.int 0 + %int0_1839 = torch.constant.int 0 + %int1_1840 = torch.constant.int 1 + %1752 = torch.aten.slice.Tensor %1749, %int0_1838, %int0_1839, %298, %int1_1840 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1752, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_1841 = torch.constant.int 1 %int0_1842 = torch.constant.int 0 - %int128_1843 = torch.constant.int 128 - %none_1844 = torch.constant.none - %none_1845 = torch.constant.none - %cpu_1846 = torch.constant.device "cpu" - %false_1847 = torch.constant.bool false - %1743 = torch.aten.arange.start %int0_1842, %int128_1843, %none_1844, %none_1845, %cpu_1846, %false_1847 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_1848 = torch.constant.int 2 - %1744 = torch.aten.floor_divide.Scalar %1743, %int2_1848 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_1849 = torch.constant.int 6 - %1745 = torch.prims.convert_element_type %1744, %int6_1849 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_1850 = torch.constant.int 128 - %1746 = torch.aten.div.Scalar %1745, %int128_1850 : !torch.vtensor<[128],f32>, 
!torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_1851 = torch.constant.float 2.000000e+00 - %1747 = torch.aten.mul.Scalar %1746, %float2.000000e00_1851 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_1852 = torch.constant.float 5.000000e+05 - %1748 = torch.aten.pow.Scalar %float5.000000e05_1852, %1747 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %1749 = torch.aten.reciprocal %1748 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_1853 = torch.constant.float 1.000000e+00 - %1750 = torch.aten.mul.Scalar %1749, %float1.000000e00_1853 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %int9223372036854775807_1843 = torch.constant.int 9223372036854775807 + %int1_1844 = torch.constant.int 1 + %1753 = torch.aten.slice.Tensor %1752, %int1_1841, %int0_1842, %int9223372036854775807_1843, %int1_1844 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1753, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_1845 = torch.constant.int 0 + %1754 = torch.aten.unsqueeze %1751, %int0_1845 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1754, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_1846 = torch.constant.int 1 + %int0_1847 = torch.constant.int 0 + %int9223372036854775807_1848 = torch.constant.int 9223372036854775807 + %int1_1849 = torch.constant.int 1 + %1755 = torch.aten.slice.Tensor %1754, %int1_1846, %int0_1847, %int9223372036854775807_1848, %int1_1849 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1755, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_1850 = torch.constant.int 2 + %1756 = torch.aten.unsqueeze %1755, %int2_1850 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1756, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_1851 = torch.constant.int 3 + %int0_1852 = torch.constant.int 0 + %int9223372036854775807_1853 = torch.constant.int 9223372036854775807 %int1_1854 = torch.constant.int 1 - %1751 = torch.aten.unsqueeze %1742, %int1_1854 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_1855 = torch.constant.int 0 - %1752 = torch.aten.unsqueeze %1750, %int0_1855 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %1753 = torch.aten.mul.Tensor %1751, %1752 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %1757 = torch.aten.slice.Tensor %1756, %int3_1851, %int0_1852, %int9223372036854775807_1853, %int1_1854 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1757, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_1855 = torch.constant.int 4 %int1_1856 = torch.constant.int 1 - %1754 = torch.aten.size.int %1700, %int1_1856 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_1857 = torch.constant.int 0 - %1755 = torch.aten.add.int %int0_1857, %1754 : !torch.int, !torch.int -> !torch.int - %int0_1858 = torch.constant.int 0 + %int1_1857 = torch.constant.int 1 + 
%int1_1858 = torch.constant.int 1 + %1758 = torch.prim.ListConstruct %int4_1855, %int1_1856, %int1_1857, %int1_1858 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1759 = torch.aten.repeat %1757, %1758 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %1759, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> %int0_1859 = torch.constant.int 0 + %1760 = torch.aten.unsqueeze %1753, %int0_1859 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1760, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int1_1860 = torch.constant.int 1 - %1756 = torch.aten.slice.Tensor %1753, %int0_1858, %int0_1859, %1755, %int1_1860 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1756, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1861 = torch.constant.int 1 - %int0_1862 = torch.constant.int 0 - %int9223372036854775807_1863 = torch.constant.int 9223372036854775807 - %int1_1864 = torch.constant.int 1 - %1757 = torch.aten.slice.Tensor %1756, %int1_1861, %int0_1862, %int9223372036854775807_1863, %int1_1864 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1757, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_1865 = torch.constant.int 1 + %int0_1861 = torch.constant.int 0 + %int9223372036854775807_1862 = torch.constant.int 9223372036854775807 + %int1_1863 = torch.constant.int 1 + %1761 = torch.aten.slice.Tensor %1760, %int1_1860, %int0_1861, %int9223372036854775807_1862, %int1_1863 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1761, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_1864 = torch.constant.int 2 + %1762 = torch.aten.unsqueeze %1761, %int2_1864 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1762, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_1865 = torch.constant.int 3 %int0_1866 = torch.constant.int 0 %int9223372036854775807_1867 = torch.constant.int 9223372036854775807 %int1_1868 = torch.constant.int 1 - %1758 = torch.aten.slice.Tensor %1757, %int1_1865, %int0_1866, %int9223372036854775807_1867, %int1_1868 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1758, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_1869 = torch.constant.int 0 - %1759 = torch.aten.unsqueeze %1758, %int0_1869 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1759, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> + %1763 = torch.aten.slice.Tensor %1762, %int3_1865, %int0_1866, %int9223372036854775807_1867, %int1_1868 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1763, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_1869 = torch.constant.int 4 %int1_1870 = torch.constant.int 1 - %int0_1871 = 
torch.constant.int 0 - %int9223372036854775807_1872 = torch.constant.int 9223372036854775807 - %int1_1873 = torch.constant.int 1 - %1760 = torch.aten.slice.Tensor %1759, %int1_1870, %int0_1871, %int9223372036854775807_1872, %int1_1873 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1760, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_1874 = torch.constant.int 2 - %int0_1875 = torch.constant.int 0 - %int9223372036854775807_1876 = torch.constant.int 9223372036854775807 - %int1_1877 = torch.constant.int 1 - %1761 = torch.aten.slice.Tensor %1760, %int2_1874, %int0_1875, %int9223372036854775807_1876, %int1_1877 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1761, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_1878 = torch.constant.int 4 - %int1_1879 = torch.constant.int 1 + %int1_1871 = torch.constant.int 1 + %int1_1872 = torch.constant.int 1 + %1764 = torch.prim.ListConstruct %int4_1869, %int1_1870, %int1_1871, %int1_1872 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1765 = torch.aten.repeat %1763, %1764 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %1765, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %1766 = torch.aten.mul.Tensor %1706, %1759 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1766, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_1873 = torch.constant.int 3 + %int0_1874 = torch.constant.int 0 + %int64_1875 = torch.constant.int 64 + %int1_1876 = torch.constant.int 1 + %1767 = torch.aten.slice.Tensor %1706, %int3_1873, %int0_1874, %int64_1875, %int1_1876 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %1767, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_1877 = torch.constant.int 3 + %int64_1878 = torch.constant.int 64 + %int9223372036854775807_1879 = torch.constant.int 9223372036854775807 %int1_1880 = torch.constant.int 1 - %1762 = torch.prim.ListConstruct %int4_1878, %int1_1879, %int1_1880 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1763 = torch.aten.repeat %1761, %1762 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %1763, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_1881 = torch.constant.int 6 - %1764 = torch.prims.convert_element_type %1711, %int6_1881 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %1764, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %1765 = torch_c.to_builtin_tensor %1764 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %1766 = torch_c.to_builtin_tensor %1763 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %1767 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%1765, %1766) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %1768 = torch_c.from_builtin_tensor %1767 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - 
torch.bind_symbolic_shape %1768, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_1882 = torch.constant.int 5 - %1769 = torch.prims.convert_element_type %1768, %int5_1882 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1769, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_1883 = torch.constant.int 64 - %1770 = torch.aten.mul.Scalar %arg2, %int64_1883 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1770, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int14 = torch.constant.int 14 - %int1_1884 = torch.constant.int 1 - %1771 = torch.aten.add.Scalar %1770, %int14, %int1_1884 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1771, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_1885 = torch.constant.int 4 - %int32_1886 = torch.constant.int 32 - %int8_1887 = torch.constant.int 8 - %int128_1888 = torch.constant.int 128 - %1772 = torch.prim.ListConstruct %int4_1885, %398, %int32_1886, %int8_1887, %int128_1888 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1773 = torch.aten.view %1769, %1772 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1773, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_1889 = torch.constant.int 4 - %1774 = torch.aten.mul.int %int4_1889, %398 : !torch.int, !torch.int -> !torch.int - %int32_1890 = torch.constant.int 32 - %int8_1891 = torch.constant.int 8 - %int128_1892 = torch.constant.int 128 - %1775 = torch.prim.ListConstruct %1774, %int32_1890, %int8_1891, %int128_1892 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1776 = torch.aten.view %1773, %1775 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1776, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_1893 = torch.constant.int 4 - %1777 = torch.aten.mul.int %int4_1893, %398 : !torch.int, !torch.int -> !torch.int - %1778 = torch.prim.ListConstruct %1777 : (!torch.int) -> !torch.list - %1779 = torch.aten.view %1771, %1778 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %1779, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_1894 = torch.constant.int 32 - %int2_1895 = torch.constant.int 2 - %int32_1896 = torch.constant.int 32 - %int8_1897 = torch.constant.int 8 - %int128_1898 = torch.constant.int 128 - %1780 = torch.prim.ListConstruct %389, %int32_1894, %int2_1895, %int32_1896, %int8_1897, %int128_1898 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1781 = torch.aten.view %1613, %1780 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1781, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_1899 = torch.constant.int 32 - %1782 = torch.aten.mul.int %389, %int32_1899 : !torch.int, !torch.int -> !torch.int - %int2_1900 = torch.constant.int 2 - %1783 = torch.aten.mul.int %1782, %int2_1900 : !torch.int, !torch.int -> !torch.int - %int32_1901 = torch.constant.int 32 - %int8_1902 = torch.constant.int 8 - 
%int128_1903 = torch.constant.int 128 - %1784 = torch.prim.ListConstruct %1783, %int32_1901, %int8_1902, %int128_1903 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1785 = torch.aten.view %1781, %1784 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1785, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %1786 = torch.prim.ListConstruct %1779 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_1904 = torch.constant.bool false - %1787 = torch.aten.index_put %1785, %1786, %1776, %false_1904 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1787, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_1905 = torch.constant.int 32 - %int2_1906 = torch.constant.int 2 - %int32_1907 = torch.constant.int 32 + %1768 = torch.aten.slice.Tensor %1706, %int3_1877, %int64_1878, %int9223372036854775807_1879, %int1_1880 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %1768, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %1769 = torch.aten.neg %1768 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %1769, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %1770 = torch.prim.ListConstruct %1769, %1767 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_1881 = torch.constant.int -1 + %1771 = torch.aten.cat %1770, %int-1_1881 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1771, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %1772 = torch.aten.mul.Tensor %1771, %1765 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1772, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_1882 = torch.constant.int 1 + %1773 = torch.aten.add.Tensor %1766, %1772, %int1_1882 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1773, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_1883 = torch.constant.int 131072 + %none_1884 = torch.constant.none + %none_1885 = torch.constant.none + %cpu_1886 = torch.constant.device "cpu" + %false_1887 = torch.constant.bool false + %1774 = torch.aten.arange %int131072_1883, %none_1884, %none_1885, %cpu_1886, %false_1887 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_1888 = torch.constant.int 0 + %int128_1889 = torch.constant.int 128 + %int2_1890 = torch.constant.int 2 + %int4_1891 = torch.constant.int 4 + %none_1892 = torch.constant.none + %cpu_1893 = torch.constant.device "cpu" + %false_1894 = torch.constant.bool false + %1775 = torch.aten.arange.start_step %int0_1888, %int128_1889, %int2_1890, %int4_1891, %none_1892, %cpu_1893, %false_1894 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_1895 = torch.constant.int 6 + %1776 = torch.prims.convert_element_type %1775, 
%int6_1895 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_1896 = torch.constant.int 128 + %1777 = torch.aten.div.Scalar %1776, %int128_1896 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_1897 = torch.constant.float 5.000000e+05 + %1778 = torch.aten.pow.Scalar %float5.000000e05_1897, %1777 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1779 = torch.aten.reciprocal %1778 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_1898 = torch.constant.float 1.000000e+00 + %1780 = torch.aten.mul.Scalar %1779, %float1.000000e00_1898 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %1781 = torch.aten.reciprocal %1780 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_1899 = torch.constant.float 6.2831853071795862 + %1782 = torch.aten.mul.Scalar %1781, %float6.283190e00_1899 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_1900 = torch.constant.float 8.192000e+03 + %1783 = torch.aten.gt.Scalar %1782, %float8.192000e03_1900 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_1901 = torch.constant.int 8 + %1784 = torch.aten.div.Scalar %1780, %int8_1901 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1785 = torch.aten.where.self %1783, %1784, %1780 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1786 = torch.aten.reciprocal %1782 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_1902 = torch.constant.int 8192 + %1787 = torch.aten.mul.Scalar %1786, %int8192_1902 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_1903 = torch.constant.int 1 + %int1_1904 = torch.constant.int 1 + %1788 = torch.aten.sub.Scalar %1787, %int1_1903, %int1_1904 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_1905 = torch.constant.int 3 + %1789 = torch.aten.div.Scalar %1788, %int3_1905 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_1906 = torch.constant.int 1 + %int1_1907 = torch.constant.int 1 + %1790 = torch.aten.rsub.Scalar %1789, %int1_1906, %int1_1907 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %1791 = torch.aten.mul.Tensor %1790, %1785 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> %int8_1908 = torch.constant.int 8 - %int128_1909 = torch.constant.int 128 - %1788 = torch.prim.ListConstruct %389, %int32_1905, %int2_1906, %int32_1907, %int8_1908, %int128_1909 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1789 = torch.aten.view %1787, %1788 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1789, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_1910 = torch.constant.int 2097152 - %1790 = torch.prim.ListConstruct %389, %int2097152_1910 : (!torch.int, !torch.int) -> !torch.list - %1791 = torch.aten.view %1789, %1790 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1791, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_1911 = torch.constant.int 32 - %int2_1912 = torch.constant.int 2 - %int32_1913 = torch.constant.int 32 - %int8_1914 = torch.constant.int 8 - %int128_1915 = 
torch.constant.int 128 - %1792 = torch.prim.ListConstruct %389, %int32_1911, %int2_1912, %int32_1913, %int8_1914, %int128_1915 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1793 = torch.aten.view %1791, %1792 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1793, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_1916 = torch.constant.int 32 - %int8_1917 = torch.constant.int 8 - %int128_1918 = torch.constant.int 128 - %1794 = torch.prim.ListConstruct %1783, %int32_1916, %int8_1917, %int128_1918 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1795 = torch.aten.view %1793, %1794 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1795, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_1919 = torch.constant.int 4 - %int32_1920 = torch.constant.int 32 - %int8_1921 = torch.constant.int 8 - %int128_1922 = torch.constant.int 128 - %1796 = torch.prim.ListConstruct %int4_1919, %398, %int32_1920, %int8_1921, %int128_1922 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1797 = torch.aten.view %1713, %1796 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1797, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_1923 = torch.constant.int 4 - %1798 = torch.aten.mul.int %int4_1923, %398 : !torch.int, !torch.int -> !torch.int - %int32_1924 = torch.constant.int 32 - %int8_1925 = torch.constant.int 8 - %int128_1926 = torch.constant.int 128 - %1799 = torch.prim.ListConstruct %1798, %int32_1924, %int8_1925, %int128_1926 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1800 = torch.aten.view %1797, %1799 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1800, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_1927 = torch.constant.int 1 - %int1_1928 = torch.constant.int 1 - %1801 = torch.aten.add.Scalar %1771, %int1_1927, %int1_1928 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1801, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_1929 = torch.constant.int 4 - %1802 = torch.aten.mul.int %int4_1929, %398 : !torch.int, !torch.int -> !torch.int - %1803 = torch.prim.ListConstruct %1802 : (!torch.int) -> !torch.list - %1804 = torch.aten.view %1801, %1803 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %1804, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %1805 = torch.prim.ListConstruct %1804 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_1930 = torch.constant.bool false - %1806 = torch.aten.index_put %1795, %1805, %1800, %false_1930 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1806, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_1931 = torch.constant.int 32 - %int2_1932 = torch.constant.int 2 - %int32_1933 = torch.constant.int 32 - %int8_1934 = torch.constant.int 8 - %int128_1935 = torch.constant.int 128 - %1807 = 
torch.prim.ListConstruct %389, %int32_1931, %int2_1932, %int32_1933, %int8_1934, %int128_1935 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1808 = torch.aten.view %1806, %1807 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1808, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_1936 = torch.constant.int 2097152 - %1809 = torch.prim.ListConstruct %389, %int2097152_1936 : (!torch.int, !torch.int) -> !torch.list - %1810 = torch.aten.view %1808, %1809 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1810, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_1937 = torch.constant.int -2 - %1811 = torch.aten.unsqueeze %1769, %int-2_1937 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1811, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_1938 = torch.constant.int 4 - %int8_1939 = torch.constant.int 8 - %int4_1940 = torch.constant.int 4 - %int128_1941 = torch.constant.int 128 - %1812 = torch.prim.ListConstruct %int4_1938, %1754, %int8_1939, %int4_1940, %int128_1941 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1942 = torch.constant.bool false - %1813 = torch.aten.expand %1811, %1812, %false_1942 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1813, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1943 = torch.constant.int 0 - %1814 = torch.aten.clone %1813, %int0_1943 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1814, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %1792 = torch.aten.div.Scalar %1791, %int8_1908 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1793 = torch.aten.mul.Tensor %1789, %1785 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_1909 = torch.constant.int 1 + %1794 = torch.aten.add.Tensor %1792, %1793, %int1_1909 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_1910 = torch.constant.float 2.048000e+03 + %1795 = torch.aten.lt.Scalar %1782, %float2.048000e03_1910 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %1796 = torch.aten.bitwise_not %1795 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_1911 = torch.constant.float 8.192000e+03 + %1797 = torch.aten.gt.Scalar %1782, %float8.192000e03_1911 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %1798 = torch.aten.bitwise_not %1797 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %1799 = torch.aten.mul.Tensor %1796, %1798 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %1800 = torch.aten.where.self %1799, %1794, %1785 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1801 = torch.prim.ListConstruct %1800, %1800 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_1912 = torch.constant.int -1 + %1802 = torch.aten.cat %1801, %int-1_1912 : !torch.list, !torch.int 
-> !torch.vtensor<[128],f32> + %int6_1913 = torch.constant.int 6 + %1803 = torch.prims.convert_element_type %1802, %int6_1913 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_1914 = torch.constant.int 1 + %1804 = torch.aten.unsqueeze %1774, %int1_1914 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_1915 = torch.constant.int 6 + %1805 = torch.prims.convert_element_type %1804, %int6_1915 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_1916 = torch.constant.int 0 + %1806 = torch.aten.unsqueeze %1803, %int0_1916 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_1917 = torch.constant.int 6 + %1807 = torch.prims.convert_element_type %1806, %int6_1917 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %1808 = torch.aten.mul.Tensor %1805, %1807 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %1809 = torch.aten.cos %1808 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_1918 = torch.constant.int 5 + %1810 = torch.prims.convert_element_type %1809, %int5_1918 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %1811 = torch.aten.sin %1808 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_1919 = torch.constant.int 5 + %1812 = torch.prims.convert_element_type %1811, %int5_1919 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_1920 = torch.constant.int 0 + %int0_1921 = torch.constant.int 0 + %int1_1922 = torch.constant.int 1 + %1813 = torch.aten.slice.Tensor %1810, %int0_1920, %int0_1921, %298, %int1_1922 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1813, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_1923 = torch.constant.int 1 + %int0_1924 = torch.constant.int 0 + %int9223372036854775807_1925 = torch.constant.int 9223372036854775807 + %int1_1926 = torch.constant.int 1 + %1814 = torch.aten.slice.Tensor %1813, %int1_1923, %int0_1924, %int9223372036854775807_1925, %int1_1926 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1814, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_1927 = torch.constant.int 0 + %int0_1928 = torch.constant.int 0 + %int1_1929 = torch.constant.int 1 + %1815 = torch.aten.slice.Tensor %1812, %int0_1927, %int0_1928, %298, %int1_1929 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1815, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_1930 = torch.constant.int 1 + %int0_1931 = torch.constant.int 0 + %int9223372036854775807_1932 = torch.constant.int 9223372036854775807 + %int1_1933 = torch.constant.int 1 + %1816 = torch.aten.slice.Tensor %1815, %int1_1930, %int0_1931, %int9223372036854775807_1932, %int1_1933 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1816, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_1934 = torch.constant.int 0 + %1817 = torch.aten.unsqueeze %1814, %int0_1934 : !torch.vtensor<[?,128],f16>, !torch.int -> 
!torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1817, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_1935 = torch.constant.int 1 + %int0_1936 = torch.constant.int 0 + %int9223372036854775807_1937 = torch.constant.int 9223372036854775807 + %int1_1938 = torch.constant.int 1 + %1818 = torch.aten.slice.Tensor %1817, %int1_1935, %int0_1936, %int9223372036854775807_1937, %int1_1938 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %1818, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_1939 = torch.constant.int 2 + %1819 = torch.aten.unsqueeze %1818, %int2_1939 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1819, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_1940 = torch.constant.int 3 + %int0_1941 = torch.constant.int 0 + %int9223372036854775807_1942 = torch.constant.int 9223372036854775807 + %int1_1943 = torch.constant.int 1 + %1820 = torch.aten.slice.Tensor %1819, %int3_1940, %int0_1941, %int9223372036854775807_1942, %int1_1943 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %1820, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> %int4_1944 = torch.constant.int 4 - %int32_1945 = torch.constant.int 32 - %int128_1946 = torch.constant.int 128 - %1815 = torch.prim.ListConstruct %int4_1944, %1754, %int32_1945, %int128_1946 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1816 = torch.aten._unsafe_view %1814, %1815 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1816, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_1947 = torch.constant.int -2 - %1817 = torch.aten.unsqueeze %1713, %int-2_1947 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1817, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_1948 = torch.constant.int 1 - %1818 = torch.aten.size.int %1707, %int1_1948 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_1949 = torch.constant.int 4 - %int8_1950 = torch.constant.int 8 - %int4_1951 = torch.constant.int 4 - %int128_1952 = torch.constant.int 128 - %1819 = torch.prim.ListConstruct %int4_1949, %1818, %int8_1950, %int4_1951, %int128_1952 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1953 = torch.constant.bool false - %1820 = torch.aten.expand %1817, %1819, %false_1953 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1820, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1954 = torch.constant.int 0 - %1821 = torch.aten.clone %1820, %int0_1954 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1821, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_1955 = torch.constant.int 4 - %int32_1956 = torch.constant.int 32 - %int128_1957 = torch.constant.int 128 - %1822 = torch.prim.ListConstruct %int4_1955, %1818, %int32_1956, %int128_1957 : (!torch.int, 
!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1823 = torch.aten._unsafe_view %1821, %1822 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %1823, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int1_1958 = torch.constant.int 1
- %int2_1959 = torch.constant.int 2
- %1824 = torch.aten.transpose.int %1741, %int1_1958, %int2_1959 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %1824, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_1945 = torch.constant.int 1
+ %int1_1946 = torch.constant.int 1
+ %int1_1947 = torch.constant.int 1
+ %1821 = torch.prim.ListConstruct %int4_1944, %int1_1945, %int1_1946, %int1_1947 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1822 = torch.aten.repeat %1820, %1821 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %1822, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+ %int0_1948 = torch.constant.int 0
+ %1823 = torch.aten.unsqueeze %1816, %int0_1948 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %1823, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int1_1949 = torch.constant.int 1
+ %int0_1950 = torch.constant.int 0
+ %int9223372036854775807_1951 = torch.constant.int 9223372036854775807
+ %int1_1952 = torch.constant.int 1
+ %1824 = torch.aten.slice.Tensor %1823, %int1_1949, %int0_1950, %int9223372036854775807_1951, %int1_1952 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %1824, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int2_1953 = torch.constant.int 2
+ %1825 = torch.aten.unsqueeze %1824, %int2_1953 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %1825, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_1954 = torch.constant.int 3
+ %int0_1955 = torch.constant.int 0
+ %int9223372036854775807_1956 = torch.constant.int 9223372036854775807
+ %int1_1957 = torch.constant.int 1
+ %1826 = torch.aten.slice.Tensor %1825, %int3_1954, %int0_1955, %int9223372036854775807_1956, %int1_1957 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %1826, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int4_1958 = torch.constant.int 4
+ %int1_1959 = torch.constant.int 1
  %int1_1960 = torch.constant.int 1
- %int2_1961 = torch.constant.int 2
- %1825 = torch.aten.transpose.int %1816, %int1_1960, %int2_1961 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %1825, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_1962 = torch.constant.int 1
- %int2_1963 = torch.constant.int 2
- %1826 = torch.aten.transpose.int %1823, %int1_1962, %int2_1963 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %1826, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %float0.000000e00_1964 = torch.constant.float 0.000000e+00
- %true_1965 = torch.constant.bool true
- %none_1966 = torch.constant.none
- %none_1967 = torch.constant.none
- %1827:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1824, %1825, %1826, %float0.000000e00_1964, %true_1965, %none_1966, %none_1967) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>)
- torch.bind_symbolic_shape %1827#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_1968 = torch.constant.int 1
- %int2_1969 = torch.constant.int 2
- %1828 = torch.aten.transpose.int %1827#0, %int1_1968, %int2_1969 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %1828, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int4_1970 = torch.constant.int 4
- %int4096_1971 = torch.constant.int 4096
- %1829 = torch.prim.ListConstruct %int4_1970, %1726, %int4096_1971 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1830 = torch.aten.view %1828, %1829 : !torch.vtensor<[4,?,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1830, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_1972 = torch.constant.int -2
- %int-1_1973 = torch.constant.int -1
- %1831 = torch.aten.transpose.int %68, %int-2_1972, %int-1_1973 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_1974 = torch.constant.int 4
- %1832 = torch.aten.mul.int %int4_1974, %1726 : !torch.int, !torch.int -> !torch.int
- %int4096_1975 = torch.constant.int 4096
- %1833 = torch.prim.ListConstruct %1832, %int4096_1975 : (!torch.int, !torch.int) -> !torch.list<int>
- %1834 = torch.aten.view %1830, %1833 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1834, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %1835 = torch.aten.mm %1834, %1831 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %1835, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_1976 = torch.constant.int 4
- %int4096_1977 = torch.constant.int 4096
- %1836 = torch.prim.ListConstruct %int4_1976, %1726, %int4096_1977 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1837 = torch.aten.view %1835, %1836 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1837, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int1_1978 = torch.constant.int 1
- %1838 = torch.aten.add.Tensor %1676, %1837, %int1_1978 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %1838, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int6_1979 = torch.constant.int 6
- %1839 = torch.prims.convert_element_type %1838, %int6_1979 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1839, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int2_1980 = torch.constant.int 2
- %1840 = torch.aten.pow.Tensor_Scalar %1839, %int2_1980 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %1840, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int-1_1981 = torch.constant.int -1
- %1841 = torch.prim.ListConstruct %int-1_1981 : (!torch.int) -> !torch.list<int>
- %true_1982 = torch.constant.bool true
- %none_1983 = torch.constant.none
- %1842 = torch.aten.mean.dim %1840, %1841, %true_1982, %none_1983 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %1842, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %float9.999990e-06_1984 = torch.constant.float 9.9999997473787516E-6
+ %int1_1961 = torch.constant.int 1
+ %1827 = torch.prim.ListConstruct %int4_1958, %int1_1959, %int1_1960, %int1_1961 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1828 = torch.aten.repeat %1826, %1827 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %1828, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+ %1829 = torch.aten.mul.Tensor %1708, %1822 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %1829, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int3_1962 = torch.constant.int 3
+ %int0_1963 = torch.constant.int 0
+ %int64_1964 = torch.constant.int 64
+ %int1_1965 = torch.constant.int 1
+ %1830 = torch.aten.slice.Tensor %1708, %int3_1962, %int0_1963, %int64_1964, %int1_1965 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16>
+ torch.bind_symbolic_shape %1830, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16>
+ %int3_1966 = torch.constant.int 3
+ %int64_1967 = torch.constant.int 64
+ %int9223372036854775807_1968 = torch.constant.int 9223372036854775807
+ %int1_1969 = torch.constant.int 1
+ %1831 = torch.aten.slice.Tensor %1708, %int3_1966, %int64_1967, %int9223372036854775807_1968, %int1_1969 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16>
+ torch.bind_symbolic_shape %1831, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16>
+ %1832 = torch.aten.neg %1831 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16>
+ torch.bind_symbolic_shape %1832, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16>
+ %1833 = torch.prim.ListConstruct %1832, %1830 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list<vtensor>
+ %int-1_1970 = torch.constant.int -1
+ %1834 = torch.aten.cat %1833, %int-1_1970 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %1834, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %1835 = torch.aten.mul.Tensor %1834, %1828 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %1835, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int1_1971 = torch.constant.int 1
+ %1836 = torch.aten.add.Tensor %1829, %1835, %int1_1971 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %1836, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_1972 = torch.constant.int 32 + %1837 = torch.aten.mul.Scalar %arg2, %int32_1972 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1837, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int5_1973 = torch.constant.int 5 + %int1_1974 = torch.constant.int 1 + %1838 = torch.aten.add.Scalar %1837, %int5_1973, %int1_1974 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1838, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_1975 = torch.constant.int 2 + %1839 = torch.aten.mul.Scalar %1838, %int2_1975 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1839, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_1976 = torch.constant.int 0 + %int1_1977 = torch.constant.int 1 + %1840 = torch.aten.add.Scalar %1839, %int0_1976, %int1_1977 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %1840, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %1841 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %1842 = torch.aten.view %1840, %1841 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %1842, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_1978 = torch.constant.int 4 + %int32_1979 = torch.constant.int 32 + %int8_1980 = torch.constant.int 8 + %int128_1981 = torch.constant.int 128 + %1843 = torch.prim.ListConstruct %int4_1978, %296, %int32_1979, %int8_1980, %int128_1981 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1844 = torch.aten.view %1836, %1843 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1844, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_1982 = torch.constant.int 32 + %int8_1983 = torch.constant.int 8 + %int128_1984 = torch.constant.int 128 + %1845 = torch.prim.ListConstruct %504, %int32_1982, %int8_1983, %int128_1984 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1846 = torch.aten.view %1844, %1845 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %1846, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> %int1_1985 = torch.constant.int 1 - %1843 = torch.aten.add.Scalar %1842, %float9.999990e-06_1984, %int1_1985 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1843, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %1844 = torch.aten.rsqrt %1843 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1844, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %1845 = torch.aten.mul.Tensor %1839, %1844 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1845, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_1986 = torch.constant.int 5 - %1846 = torch.prims.convert_element_type %1845, %int5_1986 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> 
!torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1846, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %1847 = torch.aten.mul.Tensor %69, %1846 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1847, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_1986 = torch.constant.int 2 + %1847 = torch.aten.transpose.int %1846, %int1_1985, %int2_1986 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %1847, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> %int5_1987 = torch.constant.int 5 - %1848 = torch.prims.convert_element_type %1847, %int5_1987 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1848, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_1988 = torch.constant.int -2 - %int-1_1989 = torch.constant.int -1 - %1849 = torch.aten.transpose.int %70, %int-2_1988, %int-1_1989 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_1990 = torch.constant.int 4 - %1850 = torch.aten.mul.int %int4_1990, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1991 = torch.constant.int 4096 - %1851 = torch.prim.ListConstruct %1850, %int4096_1991 : (!torch.int, !torch.int) -> !torch.list - %1852 = torch.aten.view %1848, %1851 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1852, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1853 = torch.aten.mm %1852, %1849 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %1853, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_1992 = torch.constant.int 4 - %int14336_1993 = torch.constant.int 14336 - %1854 = torch.prim.ListConstruct %int4_1992, %306, %int14336_1993 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1855 = torch.aten.view %1853, %1854 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %1855, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %1856 = torch.aten.silu %1855 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %1856, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_1994 = torch.constant.int -2 - %int-1_1995 = torch.constant.int -1 - %1857 = torch.aten.transpose.int %71, %int-2_1994, %int-1_1995 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_1996 = torch.constant.int 4 - %1858 = torch.aten.mul.int %int4_1996, %306 : !torch.int, !torch.int -> !torch.int - %int4096_1997 = torch.constant.int 4096 - %1859 = torch.prim.ListConstruct %1858, %int4096_1997 : (!torch.int, !torch.int) -> !torch.list - %1860 = torch.aten.view %1848, %1859 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1860, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1861 = torch.aten.mm %1860, %1857 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %1861, [%292], 
affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_1998 = torch.constant.int 4 - %int14336_1999 = torch.constant.int 14336 - %1862 = torch.prim.ListConstruct %int4_1998, %306, %int14336_1999 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1863 = torch.aten.view %1861, %1862 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %1863, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %1864 = torch.aten.mul.Tensor %1856, %1863 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %1864, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_2000 = torch.constant.int -2 - %int-1_2001 = torch.constant.int -1 - %1865 = torch.aten.transpose.int %72, %int-2_2000, %int-1_2001 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_2002 = torch.constant.int 1 - %1866 = torch.aten.size.int %1855, %int1_2002 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_2003 = torch.constant.int 4 - %1867 = torch.aten.mul.int %int4_2003, %1866 : !torch.int, !torch.int -> !torch.int - %int14336_2004 = torch.constant.int 14336 - %1868 = torch.prim.ListConstruct %1867, %int14336_2004 : (!torch.int, !torch.int) -> !torch.list - %1869 = torch.aten.view %1864, %1868 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %1869, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %1870 = torch.aten.mm %1869, %1865 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1870, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_2005 = torch.constant.int 4 - %int4096_2006 = torch.constant.int 4096 - %1871 = torch.prim.ListConstruct %int4_2005, %1866, %int4096_2006 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1872 = torch.aten.view %1870, %1871 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1872, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_2007 = torch.constant.int 1 - %1873 = torch.aten.add.Tensor %1838, %1872, %int1_2007 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1873, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_2008 = torch.constant.int 6 - %1874 = torch.prims.convert_element_type %1873, %int6_2008 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1874, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_2009 = torch.constant.int 2 - %1875 = torch.aten.pow.Tensor_Scalar %1874, %int2_2009 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1875, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_2010 = torch.constant.int -1 - %1876 = torch.prim.ListConstruct %int-1_2010 : (!torch.int) -> !torch.list - %true_2011 = torch.constant.bool true - %none_2012 = torch.constant.none - %1877 = torch.aten.mean.dim %1875, %1876, %true_2011, %none_2012 : !torch.vtensor<[4,?,4096],f32>, !torch.list, 
!torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1877, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_2013 = torch.constant.float 9.9999997473787516E-6 - %int1_2014 = torch.constant.int 1 - %1878 = torch.aten.add.Scalar %1877, %float9.999990e-06_2013, %int1_2014 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1878, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %1879 = torch.aten.rsqrt %1878 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %1879, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %1880 = torch.aten.mul.Tensor %1874, %1879 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1880, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2015 = torch.constant.int 5 - %1881 = torch.prims.convert_element_type %1880, %int5_2015 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1881, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %1882 = torch.aten.mul.Tensor %73, %1881 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %1882, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2016 = torch.constant.int 5 - %1883 = torch.prims.convert_element_type %1882, %int5_2016 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1883, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_2017 = torch.constant.int -2 - %int-1_2018 = torch.constant.int -1 - %1884 = torch.aten.transpose.int %74, %int-2_2017, %int-1_2018 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2019 = torch.constant.int 4 - %1885 = torch.aten.mul.int %int4_2019, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2020 = torch.constant.int 4096 - %1886 = torch.prim.ListConstruct %1885, %int4096_2020 : (!torch.int, !torch.int) -> !torch.list - %1887 = torch.aten.view %1883, %1886 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1887, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1888 = torch.aten.mm %1887, %1884 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1888, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_2021 = torch.constant.int 4 - %int4096_2022 = torch.constant.int 4096 - %1889 = torch.prim.ListConstruct %int4_2021, %306, %int4096_2022 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1890 = torch.aten.view %1888, %1889 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %1890, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_2023 = torch.constant.int -2 - %int-1_2024 = torch.constant.int -1 - %1891 = torch.aten.transpose.int %75, %int-2_2023, %int-1_2024 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2025 = torch.constant.int 4 - %1892 = torch.aten.mul.int 
%int4_2025, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2026 = torch.constant.int 4096 - %1893 = torch.prim.ListConstruct %1892, %int4096_2026 : (!torch.int, !torch.int) -> !torch.list - %1894 = torch.aten.view %1883, %1893 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1894, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1895 = torch.aten.mm %1894, %1891 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %1895, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_2027 = torch.constant.int 4 - %int1024_2028 = torch.constant.int 1024 - %1896 = torch.prim.ListConstruct %int4_2027, %306, %int1024_2028 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1897 = torch.aten.view %1895, %1896 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %1897, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_2029 = torch.constant.int -2 - %int-1_2030 = torch.constant.int -1 - %1898 = torch.aten.transpose.int %76, %int-2_2029, %int-1_2030 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2031 = torch.constant.int 4 - %1899 = torch.aten.mul.int %int4_2031, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2032 = torch.constant.int 4096 - %1900 = torch.prim.ListConstruct %1899, %int4096_2032 : (!torch.int, !torch.int) -> !torch.list - %1901 = torch.aten.view %1883, %1900 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %1901, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %1902 = torch.aten.mm %1901, %1898 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %1902, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_2033 = torch.constant.int 4 - %int1024_2034 = torch.constant.int 1024 - %1903 = torch.prim.ListConstruct %int4_2033, %306, %int1024_2034 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1904 = torch.aten.view %1902, %1903 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %1904, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %1848 = torch.prims.convert_element_type %1847, %int5_1987 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %1848, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_1988 = torch.constant.int 32 + %int2_1989 = torch.constant.int 2 + %int8_1990 = torch.constant.int 8 + %int32_1991 = torch.constant.int 32 + %int128_1992 = torch.constant.int 128 + %1849 = torch.prim.ListConstruct %297, %int32_1988, %int2_1989, %int8_1990, %int32_1991, %int128_1992 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1850 = torch.aten.view %1612, %1849 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1850, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_1993 = torch.constant.int 8 + %int32_1994 = torch.constant.int 32 + %int128_1995 = torch.constant.int 128 + %1851 = 
torch.prim.ListConstruct %497, %int8_1993, %int32_1994, %int128_1995 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1852 = torch.aten.view %1850, %1851 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1852, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %1853 = torch.prim.ListConstruct %1842 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
+ %false_1996 = torch.constant.bool false
+ %1854 = torch.aten.index_put %1852, %1853, %1848, %false_1996 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1854, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int32_1997 = torch.constant.int 32
+ %int2_1998 = torch.constant.int 2
+ %int8_1999 = torch.constant.int 8
+ %int32_2000 = torch.constant.int 32
+ %int128_2001 = torch.constant.int 128
+ %1855 = torch.prim.ListConstruct %297, %int32_1997, %int2_1998, %int8_1999, %int32_2000, %int128_2001 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1856 = torch.aten.view %1854, %1855 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1856, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_2002 = torch.constant.int 2097152
+ %1857 = torch.prim.ListConstruct %297, %int2097152_2002 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1858 = torch.aten.view %1856, %1857 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %1858, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %int32_2003 = torch.constant.int 32
+ %int2_2004 = torch.constant.int 2
+ %int8_2005 = torch.constant.int 8
+ %int32_2006 = torch.constant.int 32
+ %int128_2007 = torch.constant.int 128
+ %1859 = torch.prim.ListConstruct %297, %int32_2003, %int2_2004, %int8_2005, %int32_2006, %int128_2007 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1860 = torch.aten.view %1858, %1859 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1860, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int8_2008 = torch.constant.int 8
+ %int32_2009 = torch.constant.int 32
+ %int128_2010 = torch.constant.int 128
+ %1861 = torch.prim.ListConstruct %497, %int8_2008, %int32_2009, %int128_2010 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1862 = torch.aten.view %1860, %1861 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1862, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int32_2011 = torch.constant.int 32
+ %1863 = torch.aten.mul.Scalar %arg2, %int32_2011 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1863, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int5_2012 = torch.constant.int 5
+ %int1_2013 = torch.constant.int 1
+ %1864 = torch.aten.add.Scalar %1863, %int5_2012, %int1_2013 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1864, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int2_2014 = torch.constant.int 2
+ %1865 = torch.aten.mul.Scalar %1864, %int2_2014 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1865, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int1_2015 = torch.constant.int 1
+ %int1_2016 = torch.constant.int 1
+ %1866 = torch.aten.add.Scalar %1865, %int1_2015, %int1_2016 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %1866, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %1867 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list<int>
+ %1868 = torch.aten.view %1866, %1867 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
+ torch.bind_symbolic_shape %1868, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
+ %int4_2017 = torch.constant.int 4
+ %int32_2018 = torch.constant.int 32
+ %int8_2019 = torch.constant.int 8
+ %int128_2020 = torch.constant.int 128
+ %1869 = torch.prim.ListConstruct %int4_2017, %296, %int32_2018, %int8_2019, %int128_2020 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1870 = torch.aten.view %1710, %1869 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %1870, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int32_2021 = torch.constant.int 32
+ %int8_2022 = torch.constant.int 8
+ %int128_2023 = torch.constant.int 128
+ %1871 = torch.prim.ListConstruct %504, %int32_2021, %int8_2022, %int128_2023 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1872 = torch.aten.view %1870, %1871 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
+ torch.bind_symbolic_shape %1872, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
+ %int1_2024 = torch.constant.int 1
+ %int2_2025 = torch.constant.int 2
+ %1873 = torch.aten.transpose.int %1872, %int1_2024, %int2_2025 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1873, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int5_2026 = torch.constant.int 5
+ %1874 = torch.prims.convert_element_type %1873, %int5_2026 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1874, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %1875 = torch.prim.ListConstruct %1868 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
+ %false_2027 = torch.constant.bool false
+ %1876 = torch.aten.index_put %1862, %1875, %1874, %false_2027 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %1876, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int32_2028 = torch.constant.int 32
+ %int2_2029 = torch.constant.int 2
+ %int8_2030 = torch.constant.int 8
+ %int32_2031 = torch.constant.int 32
+ %int128_2032 = torch.constant.int 128
+ %1877 = torch.prim.ListConstruct %297, %int32_2028, %int2_2029, %int8_2030, %int32_2031, %int128_2032 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1878 = torch.aten.view %1876, %1877 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<int> -> 
!torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1878, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_2033 = torch.constant.int 2097152 + %1879 = torch.prim.ListConstruct %297, %int2097152_2033 : (!torch.int, !torch.int) -> !torch.list + %1880 = torch.aten.view %1878, %1879 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %1880, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_2034 = torch.constant.int -2 + %1881 = torch.aten.unsqueeze %1836, %int-2_2034 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %1881, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> %int4_2035 = torch.constant.int 4 - %int32_2036 = torch.constant.int 32 - %int128_2037 = torch.constant.int 128 - %1905 = torch.prim.ListConstruct %int4_2035, %306, %int32_2036, %int128_2037 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1906 = torch.aten.view %1890, %1905 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1906, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_2038 = torch.constant.int 4 - %int8_2039 = torch.constant.int 8 - %int128_2040 = torch.constant.int 128 - %1907 = torch.prim.ListConstruct %int4_2038, %306, %int8_2039, %int128_2040 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1908 = torch.aten.view %1897, %1907 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1908, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int8_2036 = torch.constant.int 8 + %int4_2037 = torch.constant.int 4 + %int128_2038 = torch.constant.int 128 + %1882 = torch.prim.ListConstruct %int4_2035, %298, %int8_2036, %int4_2037, %int128_2038 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2039 = torch.constant.bool false + %1883 = torch.aten.expand %1881, %1882, %false_2039 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1883, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2040 = torch.constant.int 0 + %1884 = torch.aten.clone %1883, %int0_2040 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1884, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_2041 = torch.constant.int 4 - %int8_2042 = torch.constant.int 8 + %int32_2042 = torch.constant.int 32 %int128_2043 = torch.constant.int 128 - %1909 = torch.prim.ListConstruct %int4_2041, %306, %int8_2042, %int128_2043 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1910 = torch.aten.view %1904, %1909 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1910, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_2044 = torch.constant.int 131072 - %none_2045 = torch.constant.none - %none_2046 = torch.constant.none - %cpu_2047 = torch.constant.device "cpu" - %false_2048 = torch.constant.bool false - %1911 = torch.aten.arange %int131072_2044, %none_2045, %none_2046, %cpu_2047, 
%false_2048 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_2049 = torch.constant.int 0 - %int128_2050 = torch.constant.int 128 - %none_2051 = torch.constant.none - %none_2052 = torch.constant.none - %cpu_2053 = torch.constant.device "cpu" - %false_2054 = torch.constant.bool false - %1912 = torch.aten.arange.start %int0_2049, %int128_2050, %none_2051, %none_2052, %cpu_2053, %false_2054 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> + %1885 = torch.prim.ListConstruct %int4_2041, %298, %int32_2042, %int128_2043 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1886 = torch.aten._unsafe_view %1884, %1885 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1886, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_2044 = torch.constant.int -2 + %1887 = torch.aten.unsqueeze %1710, %int-2_2044 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %1887, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_2045 = torch.constant.int 4 + %int8_2046 = torch.constant.int 8 + %int4_2047 = torch.constant.int 4 + %int128_2048 = torch.constant.int 128 + %1888 = torch.prim.ListConstruct %int4_2045, %298, %int8_2046, %int4_2047, %int128_2048 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2049 = torch.constant.bool false + %1889 = torch.aten.expand %1887, %1888, %false_2049 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1889, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2050 = torch.constant.int 0 + %1890 = torch.aten.clone %1889, %int0_2050 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1890, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_2051 = torch.constant.int 4 + %int32_2052 = torch.constant.int 32 + %int128_2053 = torch.constant.int 128 + %1891 = torch.prim.ListConstruct %int4_2051, %298, %int32_2052, %int128_2053 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1892 = torch.aten._unsafe_view %1890, %1891 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1892, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_2054 = torch.constant.int 1 %int2_2055 = torch.constant.int 2 - %1913 = torch.aten.floor_divide.Scalar %1912, %int2_2055 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_2056 = torch.constant.int 6 - %1914 = torch.prims.convert_element_type %1913, %int6_2056 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_2057 = torch.constant.int 128 - %1915 = torch.aten.div.Scalar %1914, %int128_2057 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_2058 = torch.constant.float 2.000000e+00 - %1916 = torch.aten.mul.Scalar %1915, %float2.000000e00_2058 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_2059 = torch.constant.float 5.000000e+05 - %1917 = torch.aten.pow.Scalar %float5.000000e05_2059, %1916 : !torch.float, 
!torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %1918 = torch.aten.reciprocal %1917 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_2060 = torch.constant.float 1.000000e+00 - %1919 = torch.aten.mul.Scalar %1918, %float1.000000e00_2060 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_2061 = torch.constant.int 1 - %1920 = torch.aten.unsqueeze %1911, %int1_2061 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_2062 = torch.constant.int 0 - %1921 = torch.aten.unsqueeze %1919, %int0_2062 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %1922 = torch.aten.mul.Tensor %1920, %1921 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %1893 = torch.aten.transpose.int %1773, %int1_2054, %int2_2055 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1893, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2056 = torch.constant.int 1 + %int2_2057 = torch.constant.int 2 + %1894 = torch.aten.transpose.int %1886, %int1_2056, %int2_2057 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1894, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2058 = torch.constant.int 1 + %int2_2059 = torch.constant.int 2 + %1895 = torch.aten.transpose.int %1892, %int1_2058, %int2_2059 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1895, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_2060 = torch.constant.float 0.000000e+00 + %false_2061 = torch.constant.bool false + %none_2062 = torch.constant.none + %1896:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1893, %1894, %1895, %float0.000000e00_2060, %false_2061, %327, %none_2062) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %1896#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_2063 = torch.constant.int 1 - %1923 = torch.aten.size.int %1890, %int1_2063 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_2064 = torch.constant.int 0 - %1924 = torch.aten.add.int %int0_2064, %1923 : !torch.int, !torch.int -> !torch.int - %int0_2065 = torch.constant.int 0 - %int0_2066 = torch.constant.int 0 - %int1_2067 = torch.constant.int 1 - %1925 = torch.aten.slice.Tensor %1922, %int0_2065, %int0_2066, %1924, %int1_2067 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1925, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2068 = torch.constant.int 1 - %int0_2069 = torch.constant.int 0 - %int9223372036854775807_2070 = torch.constant.int 9223372036854775807 - %int1_2071 = torch.constant.int 1 - %1926 = torch.aten.slice.Tensor %1925, %int1_2068, %int0_2069, %int9223372036854775807_2070, %int1_2071 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - 
torch.bind_symbolic_shape %1926, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2072 = torch.constant.int 1 - %int0_2073 = torch.constant.int 0 - %int9223372036854775807_2074 = torch.constant.int 9223372036854775807 - %int1_2075 = torch.constant.int 1 - %1927 = torch.aten.slice.Tensor %1926, %int1_2072, %int0_2073, %int9223372036854775807_2074, %int1_2075 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1927, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_2076 = torch.constant.int 0 - %1928 = torch.aten.unsqueeze %1927, %int0_2076 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1928, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_2077 = torch.constant.int 1 - %int0_2078 = torch.constant.int 0 - %int9223372036854775807_2079 = torch.constant.int 9223372036854775807 + %int2_2064 = torch.constant.int 2 + %1897 = torch.aten.transpose.int %1896#0, %int1_2063, %int2_2064 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1897, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_2065 = torch.constant.int 4 + %int4096_2066 = torch.constant.int 4096 + %1898 = torch.prim.ListConstruct %int4_2065, %298, %int4096_2066 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1899 = torch.aten.view %1897, %1898 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1899, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_2067 = torch.constant.int -2 + %int-1_2068 = torch.constant.int -1 + %1900 = torch.aten.transpose.int %51, %int-2_2067, %int-1_2068 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2069 = torch.constant.int 5 + %1901 = torch.prims.convert_element_type %1900, %int5_2069 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_2070 = torch.constant.int 4096 + %1902 = torch.prim.ListConstruct %342, %int4096_2070 : (!torch.int, !torch.int) -> !torch.list + %1903 = torch.aten.view %1899, %1902 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1903, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1904 = torch.aten.mm %1903, %1901 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1904, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_2071 = torch.constant.int 4 + %int4096_2072 = torch.constant.int 4096 + %1905 = torch.prim.ListConstruct %int4_2071, %298, %int4096_2072 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1906 = torch.aten.view %1904, %1905 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1906, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_2073 = torch.constant.int 1 + %1907 = torch.aten.add.Tensor %1673, %1906, %int1_2073 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1907, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : 
!torch.vtensor<[4,?,4096],f16> + %int6_2074 = torch.constant.int 6 + %1908 = torch.prims.convert_element_type %1907, %int6_2074 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1908, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_2075 = torch.constant.int 2 + %1909 = torch.aten.pow.Tensor_Scalar %1908, %int2_2075 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1909, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_2076 = torch.constant.int -1 + %1910 = torch.prim.ListConstruct %int-1_2076 : (!torch.int) -> !torch.list + %true_2077 = torch.constant.bool true + %none_2078 = torch.constant.none + %1911 = torch.aten.mean.dim %1909, %1910, %true_2077, %none_2078 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1911, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_2079 = torch.constant.float 9.9999997473787516E-6 %int1_2080 = torch.constant.int 1 - %1929 = torch.aten.slice.Tensor %1928, %int1_2077, %int0_2078, %int9223372036854775807_2079, %int1_2080 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1929, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_2081 = torch.constant.int 2 - %int0_2082 = torch.constant.int 0 - %int9223372036854775807_2083 = torch.constant.int 9223372036854775807 - %int1_2084 = torch.constant.int 1 - %1930 = torch.aten.slice.Tensor %1929, %int2_2081, %int0_2082, %int9223372036854775807_2083, %int1_2084 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1930, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_2085 = torch.constant.int 4 - %int1_2086 = torch.constant.int 1 - %int1_2087 = torch.constant.int 1 - %1931 = torch.prim.ListConstruct %int4_2085, %int1_2086, %int1_2087 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1932 = torch.aten.repeat %1930, %1931 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %1932, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_2088 = torch.constant.int 6 - %1933 = torch.prims.convert_element_type %1906, %int6_2088 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %1933, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %1934 = torch_c.to_builtin_tensor %1933 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %1935 = torch_c.to_builtin_tensor %1932 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %1936 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%1934, %1935) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %1937 = torch_c.from_builtin_tensor %1936 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %1937, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_2089 = torch.constant.int 5 - %1938 = torch.prims.convert_element_type %1937, %int5_2089 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> 
!torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1938, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_2090 = torch.constant.int 131072 - %none_2091 = torch.constant.none - %none_2092 = torch.constant.none - %cpu_2093 = torch.constant.device "cpu" - %false_2094 = torch.constant.bool false - %1939 = torch.aten.arange %int131072_2090, %none_2091, %none_2092, %cpu_2093, %false_2094 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_2095 = torch.constant.int 0 - %int128_2096 = torch.constant.int 128 - %none_2097 = torch.constant.none - %none_2098 = torch.constant.none - %cpu_2099 = torch.constant.device "cpu" - %false_2100 = torch.constant.bool false - %1940 = torch.aten.arange.start %int0_2095, %int128_2096, %none_2097, %none_2098, %cpu_2099, %false_2100 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_2101 = torch.constant.int 2 - %1941 = torch.aten.floor_divide.Scalar %1940, %int2_2101 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> + %1912 = torch.aten.add.Scalar %1911, %float9.999990e-06_2079, %int1_2080 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1912, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %1913 = torch.aten.rsqrt %1912 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1913, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %1914 = torch.aten.mul.Tensor %1908, %1913 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1914, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_2081 = torch.constant.int 5 + %1915 = torch.prims.convert_element_type %1914, %int5_2081 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1915, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %1916 = torch.aten.mul.Tensor %52, %1915 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1916, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_2082 = torch.constant.int 5 + %1917 = torch.prims.convert_element_type %1916, %int5_2082 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1917, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_2083 = torch.constant.int -2 + %int-1_2084 = torch.constant.int -1 + %1918 = torch.aten.transpose.int %53, %int-2_2083, %int-1_2084 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2085 = torch.constant.int 5 + %1919 = torch.prims.convert_element_type %1918, %int5_2085 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_2086 = torch.constant.int 4096 + %1920 = torch.prim.ListConstruct %342, %int4096_2086 : (!torch.int, !torch.int) -> !torch.list + %1921 = torch.aten.view %1917, %1920 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1921, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1922 = torch.aten.mm %1921, 
%1919 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %1922, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_2087 = torch.constant.int 4 + %int14336_2088 = torch.constant.int 14336 + %1923 = torch.prim.ListConstruct %int4_2087, %298, %int14336_2088 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1924 = torch.aten.view %1922, %1923 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %1924, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %1925 = torch.aten.silu %1924 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %1925, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_2089 = torch.constant.int -2 + %int-1_2090 = torch.constant.int -1 + %1926 = torch.aten.transpose.int %54, %int-2_2089, %int-1_2090 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2091 = torch.constant.int 5 + %1927 = torch.prims.convert_element_type %1926, %int5_2091 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_2092 = torch.constant.int 4096 + %1928 = torch.prim.ListConstruct %342, %int4096_2092 : (!torch.int, !torch.int) -> !torch.list + %1929 = torch.aten.view %1917, %1928 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1929, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1930 = torch.aten.mm %1929, %1927 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %1930, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_2093 = torch.constant.int 4 + %int14336_2094 = torch.constant.int 14336 + %1931 = torch.prim.ListConstruct %int4_2093, %298, %int14336_2094 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1932 = torch.aten.view %1930, %1931 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %1932, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %1933 = torch.aten.mul.Tensor %1925, %1932 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %1933, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_2095 = torch.constant.int -2 + %int-1_2096 = torch.constant.int -1 + %1934 = torch.aten.transpose.int %55, %int-2_2095, %int-1_2096 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_2097 = torch.constant.int 5 + %1935 = torch.prims.convert_element_type %1934, %int5_2097 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_2098 = torch.constant.int 14336 + %1936 = torch.prim.ListConstruct %342, %int14336_2098 : (!torch.int, !torch.int) -> !torch.list + %1937 = torch.aten.view %1933, %1936 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %1937, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %1938 = torch.aten.mm %1937, %1935 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + 
torch.bind_symbolic_shape %1938, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_2099 = torch.constant.int 4 + %int4096_2100 = torch.constant.int 4096 + %1939 = torch.prim.ListConstruct %int4_2099, %298, %int4096_2100 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1940 = torch.aten.view %1938, %1939 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1940, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_2101 = torch.constant.int 1 + %1941 = torch.aten.add.Tensor %1907, %1940, %int1_2101 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1941, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> %int6_2102 = torch.constant.int 6 - %1942 = torch.prims.convert_element_type %1941, %int6_2102 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_2103 = torch.constant.int 128 - %1943 = torch.aten.div.Scalar %1942, %int128_2103 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_2104 = torch.constant.float 2.000000e+00 - %1944 = torch.aten.mul.Scalar %1943, %float2.000000e00_2104 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_2105 = torch.constant.float 5.000000e+05 - %1945 = torch.aten.pow.Scalar %float5.000000e05_2105, %1944 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %1946 = torch.aten.reciprocal %1945 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_2106 = torch.constant.float 1.000000e+00 - %1947 = torch.aten.mul.Scalar %1946, %float1.000000e00_2106 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_2107 = torch.constant.int 1 - %1948 = torch.aten.unsqueeze %1939, %int1_2107 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_2108 = torch.constant.int 0 - %1949 = torch.aten.unsqueeze %1947, %int0_2108 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %1950 = torch.aten.mul.Tensor %1948, %1949 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_2109 = torch.constant.int 1 - %1951 = torch.aten.size.int %1897, %int1_2109 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_2110 = torch.constant.int 0 - %1952 = torch.aten.add.int %int0_2110, %1951 : !torch.int, !torch.int -> !torch.int - %int0_2111 = torch.constant.int 0 - %int0_2112 = torch.constant.int 0 - %int1_2113 = torch.constant.int 1 - %1953 = torch.aten.slice.Tensor %1950, %int0_2111, %int0_2112, %1952, %int1_2113 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1953, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2114 = torch.constant.int 1 - %int0_2115 = torch.constant.int 0 - %int9223372036854775807_2116 = torch.constant.int 9223372036854775807 - %int1_2117 = torch.constant.int 1 - %1954 = torch.aten.slice.Tensor %1953, %int1_2114, %int0_2115, %int9223372036854775807_2116, %int1_2117 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1954, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2118 = torch.constant.int 1 - 
%int0_2119 = torch.constant.int 0 - %int9223372036854775807_2120 = torch.constant.int 9223372036854775807 - %int1_2121 = torch.constant.int 1 - %1955 = torch.aten.slice.Tensor %1954, %int1_2118, %int0_2119, %int9223372036854775807_2120, %int1_2121 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %1955, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_2122 = torch.constant.int 0 - %1956 = torch.aten.unsqueeze %1955, %int0_2122 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1956, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_2123 = torch.constant.int 1 - %int0_2124 = torch.constant.int 0 - %int9223372036854775807_2125 = torch.constant.int 9223372036854775807 - %int1_2126 = torch.constant.int 1 - %1957 = torch.aten.slice.Tensor %1956, %int1_2123, %int0_2124, %int9223372036854775807_2125, %int1_2126 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1957, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_2127 = torch.constant.int 2 - %int0_2128 = torch.constant.int 0 - %int9223372036854775807_2129 = torch.constant.int 9223372036854775807 - %int1_2130 = torch.constant.int 1 - %1958 = torch.aten.slice.Tensor %1957, %int2_2127, %int0_2128, %int9223372036854775807_2129, %int1_2130 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %1958, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_2131 = torch.constant.int 4 - %int1_2132 = torch.constant.int 1 - %int1_2133 = torch.constant.int 1 - %1959 = torch.prim.ListConstruct %int4_2131, %int1_2132, %int1_2133 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1960 = torch.aten.repeat %1958, %1959 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %1960, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_2134 = torch.constant.int 6 - %1961 = torch.prims.convert_element_type %1908, %int6_2134 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %1961, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %1962 = torch_c.to_builtin_tensor %1961 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %1963 = torch_c.to_builtin_tensor %1960 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %1964 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%1962, %1963) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %1965 = torch_c.from_builtin_tensor %1964 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %1965, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_2135 = torch.constant.int 5 - %1966 = torch.prims.convert_element_type %1965, %int5_2135 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1966, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_2136 = torch.constant.int 64 - %1967 = torch.aten.mul.Scalar %arg2, %int64_2136 : !torch.vtensor<[4,?],si64>, !torch.int -> 
!torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1967, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int16 = torch.constant.int 16 - %int1_2137 = torch.constant.int 1 - %1968 = torch.aten.add.Scalar %1967, %int16, %int1_2137 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1968, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_2138 = torch.constant.int 4 - %int32_2139 = torch.constant.int 32 - %int8_2140 = torch.constant.int 8 - %int128_2141 = torch.constant.int 128 - %1969 = torch.prim.ListConstruct %int4_2138, %398, %int32_2139, %int8_2140, %int128_2141 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1970 = torch.aten.view %1966, %1969 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1970, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_2142 = torch.constant.int 4 - %1971 = torch.aten.mul.int %int4_2142, %398 : !torch.int, !torch.int -> !torch.int - %int32_2143 = torch.constant.int 32 - %int8_2144 = torch.constant.int 8 - %int128_2145 = torch.constant.int 128 - %1972 = torch.prim.ListConstruct %1971, %int32_2143, %int8_2144, %int128_2145 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1973 = torch.aten.view %1970, %1972 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1973, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %1942 = torch.prims.convert_element_type %1941, %int6_2102 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1942, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_2103 = torch.constant.int 2 + %1943 = torch.aten.pow.Tensor_Scalar %1942, %int2_2103 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1943, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_2104 = torch.constant.int -1 + %1944 = torch.prim.ListConstruct %int-1_2104 : (!torch.int) -> !torch.list + %true_2105 = torch.constant.bool true + %none_2106 = torch.constant.none + %1945 = torch.aten.mean.dim %1943, %1944, %true_2105, %none_2106 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1945, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_2107 = torch.constant.float 9.9999997473787516E-6 + %int1_2108 = torch.constant.int 1 + %1946 = torch.aten.add.Scalar %1945, %float9.999990e-06_2107, %int1_2108 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1946, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %1947 = torch.aten.rsqrt %1946 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %1947, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %1948 = torch.aten.mul.Tensor %1942, %1947 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1948, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_2109 = torch.constant.int 5 + %1949 = 
torch.prims.convert_element_type %1948, %int5_2109 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1949, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %1950 = torch.aten.mul.Tensor %56, %1949 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %1950, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_2110 = torch.constant.int 5 + %1951 = torch.prims.convert_element_type %1950, %int5_2110 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1951, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_2111 = torch.constant.int -2 + %int-1_2112 = torch.constant.int -1 + %1952 = torch.aten.transpose.int %57, %int-2_2111, %int-1_2112 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2113 = torch.constant.int 5 + %1953 = torch.prims.convert_element_type %1952, %int5_2113 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_2114 = torch.constant.int 4096 + %1954 = torch.prim.ListConstruct %342, %int4096_2114 : (!torch.int, !torch.int) -> !torch.list + %1955 = torch.aten.view %1951, %1954 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1955, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1956 = torch.aten.mm %1955, %1953 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1956, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_2115 = torch.constant.int 4 + %int4096_2116 = torch.constant.int 4096 + %1957 = torch.prim.ListConstruct %int4_2115, %298, %int4096_2116 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1958 = torch.aten.view %1956, %1957 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %1958, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_2117 = torch.constant.int -2 + %int-1_2118 = torch.constant.int -1 + %1959 = torch.aten.transpose.int %58, %int-2_2117, %int-1_2118 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2119 = torch.constant.int 5 + %1960 = torch.prims.convert_element_type %1959, %int5_2119 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_2120 = torch.constant.int 4096 + %1961 = torch.prim.ListConstruct %342, %int4096_2120 : (!torch.int, !torch.int) -> !torch.list + %1962 = torch.aten.view %1951, %1961 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1962, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1963 = torch.aten.mm %1962, %1960 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %1963, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_2121 = torch.constant.int 4 + %int1024_2122 = torch.constant.int 1024 + %1964 = torch.prim.ListConstruct %int4_2121, %298, %int1024_2122 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1965 = torch.aten.view %1963, %1964 : !torch.vtensor<[?,1024],f16>, 
!torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %1965, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_2123 = torch.constant.int -2 + %int-1_2124 = torch.constant.int -1 + %1966 = torch.aten.transpose.int %59, %int-2_2123, %int-1_2124 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2125 = torch.constant.int 5 + %1967 = torch.prims.convert_element_type %1966, %int5_2125 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_2126 = torch.constant.int 4096 + %1968 = torch.prim.ListConstruct %342, %int4096_2126 : (!torch.int, !torch.int) -> !torch.list + %1969 = torch.aten.view %1951, %1968 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %1969, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %1970 = torch.aten.mm %1969, %1967 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %1970, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_2127 = torch.constant.int 4 + %int1024_2128 = torch.constant.int 1024 + %1971 = torch.prim.ListConstruct %int4_2127, %298, %int1024_2128 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1972 = torch.aten.view %1970, %1971 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %1972, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_2129 = torch.constant.int 4 + %int32_2130 = torch.constant.int 32 + %int128_2131 = torch.constant.int 128 + %1973 = torch.prim.ListConstruct %int4_2129, %298, %int32_2130, %int128_2131 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1974 = torch.aten.view %1958, %1973 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1974, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_2132 = torch.constant.int 4 + %int8_2133 = torch.constant.int 8 + %int128_2134 = torch.constant.int 128 + %1975 = torch.prim.ListConstruct %int4_2132, %298, %int8_2133, %int128_2134 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1976 = torch.aten.view %1965, %1975 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1976, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_2135 = torch.constant.int 4 + %int8_2136 = torch.constant.int 8 + %int128_2137 = torch.constant.int 128 + %1977 = torch.prim.ListConstruct %int4_2135, %298, %int8_2136, %int128_2137 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1978 = torch.aten.view %1972, %1977 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1978, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_2138 = torch.constant.int 131072 + %none_2139 = torch.constant.none + %none_2140 = torch.constant.none + %cpu_2141 = torch.constant.device "cpu" + %false_2142 = torch.constant.bool false + %1979 = torch.aten.arange %int131072_2138, %none_2139, %none_2140, %cpu_2141, %false_2142 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_2143 = torch.constant.int 0 + 
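// Editorial note (annotation, not part of the IR): the `+` lines that follow inline the
// RoPE table construction that the `-` lines delegated to
// @sharktank_rotary_embedding_4_D_8_128_f32. Reading the constants off the IR: base
// 500000, head_dim 128, scaling factor 8, original context 8192, smoothing band
// [2048, 8192], table length 131072. A NumPy sketch of the frequency scaling
// (%1980-%2017), under those assumptions:
//
//   import numpy as np
//   inv_freq = 1.0 / (500000.0 ** (np.arange(0, 128, 2) / 128.0))      # %1980-%1985
//   wavelen = 2.0 * np.pi / inv_freq                                   # %1986-%1987
//   scaled = np.where(wavelen > 8192.0, inv_freq / 8.0, inv_freq)      # %1988-%1990
//   smooth = (8192.0 / wavelen - 1.0) / 3.0                            # %1991-%1994
//   mid = (1.0 - smooth) * scaled / 8.0 + smooth * scaled              # %1995-%1999
//   in_band = (wavelen >= 2048.0) & (wavelen <= 8192.0)                # %2000-%2004
//   freqs = np.where(in_band, mid, scaled)                             # %2005
//   freqs = np.concatenate([freqs, freqs])                             # %2006-%2007
//   angles = np.arange(131072, dtype=np.float32)[:, None] * freqs[None, :]  # %2009-%2013
//   cos16 = np.cos(angles).astype(np.float16)                          # %2014-%2015
//   sin16 = np.sin(angles).astype(np.float16)                          # %2016-%2017
//
// The rotation itself then appears further down in rotate-half form:
// q_rot = q*cos + cat(-q[..., 64:], q[..., :64])*sin  (%2034-%2041).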
%int128_2144 = torch.constant.int 128 + %int2_2145 = torch.constant.int 2 %int4_2146 = torch.constant.int 4 - %1974 = torch.aten.mul.int %int4_2146, %398 : !torch.int, !torch.int -> !torch.int - %1975 = torch.prim.ListConstruct %1974 : (!torch.int) -> !torch.list - %1976 = torch.aten.view %1968, %1975 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %1976, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_2147 = torch.constant.int 32 - %int2_2148 = torch.constant.int 2 - %int32_2149 = torch.constant.int 32 - %int8_2150 = torch.constant.int 8 + %none_2147 = torch.constant.none + %cpu_2148 = torch.constant.device "cpu" + %false_2149 = torch.constant.bool false + %1980 = torch.aten.arange.start_step %int0_2143, %int128_2144, %int2_2145, %int4_2146, %none_2147, %cpu_2148, %false_2149 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_2150 = torch.constant.int 6 + %1981 = torch.prims.convert_element_type %1980, %int6_2150 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> %int128_2151 = torch.constant.int 128 - %1977 = torch.prim.ListConstruct %389, %int32_2147, %int2_2148, %int32_2149, %int8_2150, %int128_2151 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1978 = torch.aten.view %1810, %1977 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1978, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2152 = torch.constant.int 32 - %1979 = torch.aten.mul.int %389, %int32_2152 : !torch.int, !torch.int -> !torch.int - %int2_2153 = torch.constant.int 2 - %1980 = torch.aten.mul.int %1979, %int2_2153 : !torch.int, !torch.int -> !torch.int - %int32_2154 = torch.constant.int 32 - %int8_2155 = torch.constant.int 8 - %int128_2156 = torch.constant.int 128 - %1981 = torch.prim.ListConstruct %1980, %int32_2154, %int8_2155, %int128_2156 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1982 = torch.aten.view %1978, %1981 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1982, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %1983 = torch.prim.ListConstruct %1976 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_2157 = torch.constant.bool false - %1984 = torch.aten.index_put %1982, %1983, %1973, %false_2157 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1984, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_2158 = torch.constant.int 32 - %int2_2159 = torch.constant.int 2 - %int32_2160 = torch.constant.int 32 - %int8_2161 = torch.constant.int 8 - %int128_2162 = torch.constant.int 128 - %1985 = torch.prim.ListConstruct %389, %int32_2158, %int2_2159, %int32_2160, %int8_2161, %int128_2162 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1986 = torch.aten.view %1984, %1985 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1986, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2163 = torch.constant.int 2097152 - %1987 = 
torch.prim.ListConstruct %389, %int2097152_2163 : (!torch.int, !torch.int) -> !torch.list - %1988 = torch.aten.view %1986, %1987 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1988, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_2164 = torch.constant.int 32 - %int2_2165 = torch.constant.int 2 - %int32_2166 = torch.constant.int 32 - %int8_2167 = torch.constant.int 8 - %int128_2168 = torch.constant.int 128 - %1989 = torch.prim.ListConstruct %389, %int32_2164, %int2_2165, %int32_2166, %int8_2167, %int128_2168 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1990 = torch.aten.view %1988, %1989 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1990, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2169 = torch.constant.int 32 - %int8_2170 = torch.constant.int 8 - %int128_2171 = torch.constant.int 128 - %1991 = torch.prim.ListConstruct %1980, %int32_2169, %int8_2170, %int128_2171 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1992 = torch.aten.view %1990, %1991 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1992, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_2172 = torch.constant.int 4 - %int32_2173 = torch.constant.int 32 - %int8_2174 = torch.constant.int 8 - %int128_2175 = torch.constant.int 128 - %1993 = torch.prim.ListConstruct %int4_2172, %398, %int32_2173, %int8_2174, %int128_2175 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1994 = torch.aten.view %1910, %1993 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1994, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_2176 = torch.constant.int 4 - %1995 = torch.aten.mul.int %int4_2176, %398 : !torch.int, !torch.int -> !torch.int - %int32_2177 = torch.constant.int 32 - %int8_2178 = torch.constant.int 8 - %int128_2179 = torch.constant.int 128 - %1996 = torch.prim.ListConstruct %1995, %int32_2177, %int8_2178, %int128_2179 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1997 = torch.aten.view %1994, %1996 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %1997, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_2180 = torch.constant.int 1 + %1982 = torch.aten.div.Scalar %1981, %int128_2151 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_2152 = torch.constant.float 5.000000e+05 + %1983 = torch.aten.pow.Scalar %float5.000000e05_2152, %1982 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1984 = torch.aten.reciprocal %1983 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_2153 = torch.constant.float 1.000000e+00 + %1985 = torch.aten.mul.Scalar %1984, %float1.000000e00_2153 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %1986 = torch.aten.reciprocal %1985 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_2154 = torch.constant.float 6.2831853071795862 + %1987 = torch.aten.mul.Scalar %1986, %float6.283190e00_2154 : !torch.vtensor<[64],f32>, 
!torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_2155 = torch.constant.float 8.192000e+03 + %1988 = torch.aten.gt.Scalar %1987, %float8.192000e03_2155 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_2156 = torch.constant.int 8 + %1989 = torch.aten.div.Scalar %1985, %int8_2156 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1990 = torch.aten.where.self %1988, %1989, %1985 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %1991 = torch.aten.reciprocal %1987 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_2157 = torch.constant.int 8192 + %1992 = torch.aten.mul.Scalar %1991, %int8192_2157 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_2158 = torch.constant.int 1 + %int1_2159 = torch.constant.int 1 + %1993 = torch.aten.sub.Scalar %1992, %int1_2158, %int1_2159 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_2160 = torch.constant.int 3 + %1994 = torch.aten.div.Scalar %1993, %int3_2160 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_2161 = torch.constant.int 1 + %int1_2162 = torch.constant.int 1 + %1995 = torch.aten.rsub.Scalar %1994, %int1_2161, %int1_2162 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %1996 = torch.aten.mul.Tensor %1995, %1990 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_2163 = torch.constant.int 8 + %1997 = torch.aten.div.Scalar %1996, %int8_2163 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %1998 = torch.aten.mul.Tensor %1994, %1990 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_2164 = torch.constant.int 1 + %1999 = torch.aten.add.Tensor %1997, %1998, %int1_2164 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_2165 = torch.constant.float 2.048000e+03 + %2000 = torch.aten.lt.Scalar %1987, %float2.048000e03_2165 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2001 = torch.aten.bitwise_not %2000 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_2166 = torch.constant.float 8.192000e+03 + %2002 = torch.aten.gt.Scalar %1987, %float8.192000e03_2166 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2003 = torch.aten.bitwise_not %2002 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2004 = torch.aten.mul.Tensor %2001, %2003 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2005 = torch.aten.where.self %2004, %1999, %1990 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2006 = torch.prim.ListConstruct %2005, %2005 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_2167 = torch.constant.int -1 + %2007 = torch.aten.cat %2006, %int-1_2167 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_2168 = torch.constant.int 6 + %2008 = torch.prims.convert_element_type %2007, %int6_2168 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_2169 = torch.constant.int 1 + %2009 = torch.aten.unsqueeze %1979, %int1_2169 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_2170 = torch.constant.int 6 + %2010 = torch.prims.convert_element_type %2009, %int6_2170 : !torch.vtensor<[131072,1],si64>, !torch.int -> 
!torch.vtensor<[131072,1],f32> + %int0_2171 = torch.constant.int 0 + %2011 = torch.aten.unsqueeze %2008, %int0_2171 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_2172 = torch.constant.int 6 + %2012 = torch.prims.convert_element_type %2011, %int6_2172 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %2013 = torch.aten.mul.Tensor %2010, %2012 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %2014 = torch.aten.cos %2013 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_2173 = torch.constant.int 5 + %2015 = torch.prims.convert_element_type %2014, %int5_2173 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %2016 = torch.aten.sin %2013 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_2174 = torch.constant.int 5 + %2017 = torch.prims.convert_element_type %2016, %int5_2174 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_2175 = torch.constant.int 0 + %int0_2176 = torch.constant.int 0 + %int1_2177 = torch.constant.int 1 + %2018 = torch.aten.slice.Tensor %2015, %int0_2175, %int0_2176, %298, %int1_2177 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2018, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_2178 = torch.constant.int 1 + %int0_2179 = torch.constant.int 0 + %int9223372036854775807_2180 = torch.constant.int 9223372036854775807 %int1_2181 = torch.constant.int 1 - %1998 = torch.aten.add.Scalar %1968, %int1_2180, %int1_2181 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1998, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_2182 = torch.constant.int 4 - %1999 = torch.aten.mul.int %int4_2182, %398 : !torch.int, !torch.int -> !torch.int - %2000 = torch.prim.ListConstruct %1999 : (!torch.int) -> !torch.list - %2001 = torch.aten.view %1998, %2000 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2001, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %2002 = torch.prim.ListConstruct %2001 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_2183 = torch.constant.bool false - %2003 = torch.aten.index_put %1992, %2002, %1997, %false_2183 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2003, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_2184 = torch.constant.int 32 - %int2_2185 = torch.constant.int 2 - %int32_2186 = torch.constant.int 32 - %int8_2187 = torch.constant.int 8 - %int128_2188 = torch.constant.int 128 - %2004 = torch.prim.ListConstruct %389, %int32_2184, %int2_2185, %int32_2186, %int8_2187, %int128_2188 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2005 = torch.aten.view %2003, %2004 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2005, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2189 = torch.constant.int 2097152 - %2006 = torch.prim.ListConstruct %389, %int2097152_2189 : (!torch.int, !torch.int) -> !torch.list - %2007 = 
torch.aten.view %2005, %2006 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2007, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_2190 = torch.constant.int -2 - %2008 = torch.aten.unsqueeze %1966, %int-2_2190 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2008, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_2191 = torch.constant.int 4 - %int8_2192 = torch.constant.int 8 - %int4_2193 = torch.constant.int 4 - %int128_2194 = torch.constant.int 128 - %2009 = torch.prim.ListConstruct %int4_2191, %1951, %int8_2192, %int4_2193, %int128_2194 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2195 = torch.constant.bool false - %2010 = torch.aten.expand %2008, %2009, %false_2195 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2010, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %2019 = torch.aten.slice.Tensor %2018, %int1_2178, %int0_2179, %int9223372036854775807_2180, %int1_2181 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2019, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_2182 = torch.constant.int 0 + %int0_2183 = torch.constant.int 0 + %int1_2184 = torch.constant.int 1 + %2020 = torch.aten.slice.Tensor %2017, %int0_2182, %int0_2183, %298, %int1_2184 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2020, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_2185 = torch.constant.int 1 + %int0_2186 = torch.constant.int 0 + %int9223372036854775807_2187 = torch.constant.int 9223372036854775807 + %int1_2188 = torch.constant.int 1 + %2021 = torch.aten.slice.Tensor %2020, %int1_2185, %int0_2186, %int9223372036854775807_2187, %int1_2188 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2021, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_2189 = torch.constant.int 0 + %2022 = torch.aten.unsqueeze %2019, %int0_2189 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2022, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_2190 = torch.constant.int 1 + %int0_2191 = torch.constant.int 0 + %int9223372036854775807_2192 = torch.constant.int 9223372036854775807 + %int1_2193 = torch.constant.int 1 + %2023 = torch.aten.slice.Tensor %2022, %int1_2190, %int0_2191, %int9223372036854775807_2192, %int1_2193 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2023, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_2194 = torch.constant.int 2 + %2024 = torch.aten.unsqueeze %2023, %int2_2194 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2024, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_2195 = torch.constant.int 3 %int0_2196 = torch.constant.int 0 - 
%2011 = torch.aten.clone %2010, %int0_2196 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2011, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2197 = torch.constant.int 4 - %int32_2198 = torch.constant.int 32 - %int128_2199 = torch.constant.int 128 - %2012 = torch.prim.ListConstruct %int4_2197, %1951, %int32_2198, %int128_2199 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2013 = torch.aten._unsafe_view %2011, %2012 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2013, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_2200 = torch.constant.int -2 - %2014 = torch.aten.unsqueeze %1910, %int-2_2200 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2014, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int9223372036854775807_2197 = torch.constant.int 9223372036854775807 + %int1_2198 = torch.constant.int 1 + %2025 = torch.aten.slice.Tensor %2024, %int3_2195, %int0_2196, %int9223372036854775807_2197, %int1_2198 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2025, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_2199 = torch.constant.int 4 + %int1_2200 = torch.constant.int 1 %int1_2201 = torch.constant.int 1 - %2015 = torch.aten.size.int %1904, %int1_2201 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_2202 = torch.constant.int 4 - %int8_2203 = torch.constant.int 8 - %int4_2204 = torch.constant.int 4 - %int128_2205 = torch.constant.int 128 - %2016 = torch.prim.ListConstruct %int4_2202, %2015, %int8_2203, %int4_2204, %int128_2205 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2206 = torch.constant.bool false - %2017 = torch.aten.expand %2014, %2016, %false_2206 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2017, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2207 = torch.constant.int 0 - %2018 = torch.aten.clone %2017, %int0_2207 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2018, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2208 = torch.constant.int 4 - %int32_2209 = torch.constant.int 32 - %int128_2210 = torch.constant.int 128 - %2019 = torch.prim.ListConstruct %int4_2208, %2015, %int32_2209, %int128_2210 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2020 = torch.aten._unsafe_view %2018, %2019 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2020, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_2211 = torch.constant.int 1 - %int2_2212 = torch.constant.int 2 - %2021 = torch.aten.transpose.int %1938, %int1_2211, %int2_2212 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2021, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_2213 = 
torch.constant.int 1 - %int2_2214 = torch.constant.int 2 - %2022 = torch.aten.transpose.int %2013, %int1_2213, %int2_2214 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2022, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2202 = torch.constant.int 1 + %2026 = torch.prim.ListConstruct %int4_2199, %int1_2200, %int1_2201, %int1_2202 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2027 = torch.aten.repeat %2025, %2026 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2027, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_2203 = torch.constant.int 0 + %2028 = torch.aten.unsqueeze %2021, %int0_2203 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2028, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_2204 = torch.constant.int 1 + %int0_2205 = torch.constant.int 0 + %int9223372036854775807_2206 = torch.constant.int 9223372036854775807 + %int1_2207 = torch.constant.int 1 + %2029 = torch.aten.slice.Tensor %2028, %int1_2204, %int0_2205, %int9223372036854775807_2206, %int1_2207 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2029, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_2208 = torch.constant.int 2 + %2030 = torch.aten.unsqueeze %2029, %int2_2208 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2030, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_2209 = torch.constant.int 3 + %int0_2210 = torch.constant.int 0 + %int9223372036854775807_2211 = torch.constant.int 9223372036854775807 + %int1_2212 = torch.constant.int 1 + %2031 = torch.aten.slice.Tensor %2030, %int3_2209, %int0_2210, %int9223372036854775807_2211, %int1_2212 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2031, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_2213 = torch.constant.int 4 + %int1_2214 = torch.constant.int 1 %int1_2215 = torch.constant.int 1 - %int2_2216 = torch.constant.int 2 - %2023 = torch.aten.transpose.int %2020, %int1_2215, %int2_2216 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2023, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_2217 = torch.constant.float 0.000000e+00 - %true_2218 = torch.constant.bool true - %none_2219 = torch.constant.none - %none_2220 = torch.constant.none - %2024:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2021, %2022, %2023, %float0.000000e00_2217, %true_2218, %none_2219, %none_2220) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %2024#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_2221 = torch.constant.int 1 - %int2_2222 = torch.constant.int 2 - %2025 = 
torch.aten.transpose.int %2024#0, %int1_2221, %int2_2222 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2025, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_2223 = torch.constant.int 4 - %int4096_2224 = torch.constant.int 4096 - %2026 = torch.prim.ListConstruct %int4_2223, %1923, %int4096_2224 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2027 = torch.aten.view %2025, %2026 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2027, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_2225 = torch.constant.int -2 - %int-1_2226 = torch.constant.int -1 - %2028 = torch.aten.transpose.int %77, %int-2_2225, %int-1_2226 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2227 = torch.constant.int 4 - %2029 = torch.aten.mul.int %int4_2227, %1923 : !torch.int, !torch.int -> !torch.int - %int4096_2228 = torch.constant.int 4096 - %2030 = torch.prim.ListConstruct %2029, %int4096_2228 : (!torch.int, !torch.int) -> !torch.list - %2031 = torch.aten.view %2027, %2030 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2031, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2032 = torch.aten.mm %2031, %2028 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2032, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_2229 = torch.constant.int 4 - %int4096_2230 = torch.constant.int 4096 - %2033 = torch.prim.ListConstruct %int4_2229, %1923, %int4096_2230 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2034 = torch.aten.view %2032, %2033 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2034, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_2231 = torch.constant.int 1 - %2035 = torch.aten.add.Tensor %1873, %2034, %int1_2231 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2035, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_2232 = torch.constant.int 6 - %2036 = torch.prims.convert_element_type %2035, %int6_2232 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2036, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_2233 = torch.constant.int 2 - %2037 = torch.aten.pow.Tensor_Scalar %2036, %int2_2233 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2037, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_2234 = torch.constant.int -1 - %2038 = torch.prim.ListConstruct %int-1_2234 : (!torch.int) -> !torch.list - %true_2235 = torch.constant.bool true + %int1_2216 = torch.constant.int 1 + %2032 = torch.prim.ListConstruct %int4_2213, %int1_2214, %int1_2215, %int1_2216 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2033 = torch.aten.repeat %2031, %2032 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2033, [%294], affine_map<()[s0] -> (4, s0 
* 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %2034 = torch.aten.mul.Tensor %1974, %2027 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2034, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_2217 = torch.constant.int 3 + %int0_2218 = torch.constant.int 0 + %int64_2219 = torch.constant.int 64 + %int1_2220 = torch.constant.int 1 + %2035 = torch.aten.slice.Tensor %1974, %int3_2217, %int0_2218, %int64_2219, %int1_2220 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %2035, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_2221 = torch.constant.int 3 + %int64_2222 = torch.constant.int 64 + %int9223372036854775807_2223 = torch.constant.int 9223372036854775807 + %int1_2224 = torch.constant.int 1 + %2036 = torch.aten.slice.Tensor %1974, %int3_2221, %int64_2222, %int9223372036854775807_2223, %int1_2224 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %2036, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %2037 = torch.aten.neg %2036 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %2037, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %2038 = torch.prim.ListConstruct %2037, %2035 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_2225 = torch.constant.int -1 + %2039 = torch.aten.cat %2038, %int-1_2225 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2039, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %2040 = torch.aten.mul.Tensor %2039, %2033 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2040, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_2226 = torch.constant.int 1 + %2041 = torch.aten.add.Tensor %2034, %2040, %int1_2226 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2041, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_2227 = torch.constant.int 131072 + %none_2228 = torch.constant.none + %none_2229 = torch.constant.none + %cpu_2230 = torch.constant.device "cpu" + %false_2231 = torch.constant.bool false + %2042 = torch.aten.arange %int131072_2227, %none_2228, %none_2229, %cpu_2230, %false_2231 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_2232 = torch.constant.int 0 + %int128_2233 = torch.constant.int 128 + %int2_2234 = torch.constant.int 2 + %int4_2235 = torch.constant.int 4 %none_2236 = torch.constant.none - %2039 = torch.aten.mean.dim %2037, %2038, %true_2235, %none_2236 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2039, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_2237 = torch.constant.float 9.9999997473787516E-6 - %int1_2238 = torch.constant.int 1 - %2040 = torch.aten.add.Scalar %2039, 
%float9.999990e-06_2237, %int1_2238 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2040, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2041 = torch.aten.rsqrt %2040 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2041, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2042 = torch.aten.mul.Tensor %2036, %2041 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2042, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2239 = torch.constant.int 5 - %2043 = torch.prims.convert_element_type %2042, %int5_2239 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2043, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %2044 = torch.aten.mul.Tensor %78, %2043 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2044, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2240 = torch.constant.int 5 - %2045 = torch.prims.convert_element_type %2044, %int5_2240 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2045, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_2241 = torch.constant.int -2 - %int-1_2242 = torch.constant.int -1 - %2046 = torch.aten.transpose.int %79, %int-2_2241, %int-1_2242 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_2243 = torch.constant.int 4 - %2047 = torch.aten.mul.int %int4_2243, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2244 = torch.constant.int 4096 - %2048 = torch.prim.ListConstruct %2047, %int4096_2244 : (!torch.int, !torch.int) -> !torch.list - %2049 = torch.aten.view %2045, %2048 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2049, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2050 = torch.aten.mm %2049, %2046 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2050, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_2245 = torch.constant.int 4 - %int14336_2246 = torch.constant.int 14336 - %2051 = torch.prim.ListConstruct %int4_2245, %306, %int14336_2246 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2052 = torch.aten.view %2050, %2051 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2052, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %2053 = torch.aten.silu %2052 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2053, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_2247 = torch.constant.int -2 - %int-1_2248 = torch.constant.int -1 - %2054 = torch.aten.transpose.int %80, %int-2_2247, %int-1_2248 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_2249 = torch.constant.int 4 - %2055 = torch.aten.mul.int %int4_2249, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2250 = 
torch.constant.int 4096 - %2056 = torch.prim.ListConstruct %2055, %int4096_2250 : (!torch.int, !torch.int) -> !torch.list - %2057 = torch.aten.view %2045, %2056 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2057, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2058 = torch.aten.mm %2057, %2054 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2058, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_2251 = torch.constant.int 4 - %int14336_2252 = torch.constant.int 14336 - %2059 = torch.prim.ListConstruct %int4_2251, %306, %int14336_2252 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2060 = torch.aten.view %2058, %2059 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2060, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %2061 = torch.aten.mul.Tensor %2053, %2060 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2061, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_2253 = torch.constant.int -2 - %int-1_2254 = torch.constant.int -1 - %2062 = torch.aten.transpose.int %81, %int-2_2253, %int-1_2254 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_2255 = torch.constant.int 1 - %2063 = torch.aten.size.int %2052, %int1_2255 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_2256 = torch.constant.int 4 - %2064 = torch.aten.mul.int %int4_2256, %2063 : !torch.int, !torch.int -> !torch.int - %int14336_2257 = torch.constant.int 14336 - %2065 = torch.prim.ListConstruct %2064, %int14336_2257 : (!torch.int, !torch.int) -> !torch.list - %2066 = torch.aten.view %2061, %2065 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2066, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %2067 = torch.aten.mm %2066, %2062 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2067, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_2258 = torch.constant.int 4 - %int4096_2259 = torch.constant.int 4096 - %2068 = torch.prim.ListConstruct %int4_2258, %2063, %int4096_2259 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2069 = torch.aten.view %2067, %2068 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2069, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_2260 = torch.constant.int 1 - %2070 = torch.aten.add.Tensor %2035, %2069, %int1_2260 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2070, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %cpu_2237 = torch.constant.device "cpu" + %false_2238 = torch.constant.bool false + %2043 = torch.aten.arange.start_step %int0_2232, %int128_2233, %int2_2234, %int4_2235, %none_2236, %cpu_2237, %false_2238 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_2239 = torch.constant.int 6 + 
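// Editorial note (annotation, not part of the IR): the `+` lines around here (%2042-%2080)
// repeat the identical position/frequency computation a second time — same base, scaling
// factor, and smoothing band — to build the cos/sin tables that appear to feed the K-head
// rotation (the first copy, %1979-%2017, fed the Q path at %1974). The duplication looks
// like an artifact of the non-decomposed export inlining the subgraph once per use rather
// than a semantic change. The `-` lines in this region drop the old per-block SwiGLU FFN
// (silu(gate(x)) * up(x) -> down) and the paged KV-cache index_put bookkeeping, which are
// presumably re-emitted later in the function under the new SSA numbering.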
%2044 = torch.prims.convert_element_type %2043, %int6_2239 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_2240 = torch.constant.int 128 + %2045 = torch.aten.div.Scalar %2044, %int128_2240 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_2241 = torch.constant.float 5.000000e+05 + %2046 = torch.aten.pow.Scalar %float5.000000e05_2241, %2045 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2047 = torch.aten.reciprocal %2046 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_2242 = torch.constant.float 1.000000e+00 + %2048 = torch.aten.mul.Scalar %2047, %float1.000000e00_2242 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %2049 = torch.aten.reciprocal %2048 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_2243 = torch.constant.float 6.2831853071795862 + %2050 = torch.aten.mul.Scalar %2049, %float6.283190e00_2243 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_2244 = torch.constant.float 8.192000e+03 + %2051 = torch.aten.gt.Scalar %2050, %float8.192000e03_2244 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_2245 = torch.constant.int 8 + %2052 = torch.aten.div.Scalar %2048, %int8_2245 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2053 = torch.aten.where.self %2051, %2052, %2048 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2054 = torch.aten.reciprocal %2050 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_2246 = torch.constant.int 8192 + %2055 = torch.aten.mul.Scalar %2054, %int8192_2246 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_2247 = torch.constant.int 1 + %int1_2248 = torch.constant.int 1 + %2056 = torch.aten.sub.Scalar %2055, %int1_2247, %int1_2248 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_2249 = torch.constant.int 3 + %2057 = torch.aten.div.Scalar %2056, %int3_2249 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_2250 = torch.constant.int 1 + %int1_2251 = torch.constant.int 1 + %2058 = torch.aten.rsub.Scalar %2057, %int1_2250, %int1_2251 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %2059 = torch.aten.mul.Tensor %2058, %2053 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_2252 = torch.constant.int 8 + %2060 = torch.aten.div.Scalar %2059, %int8_2252 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2061 = torch.aten.mul.Tensor %2057, %2053 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_2253 = torch.constant.int 1 + %2062 = torch.aten.add.Tensor %2060, %2061, %int1_2253 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_2254 = torch.constant.float 2.048000e+03 + %2063 = torch.aten.lt.Scalar %2050, %float2.048000e03_2254 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2064 = torch.aten.bitwise_not %2063 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_2255 = torch.constant.float 8.192000e+03 + %2065 = torch.aten.gt.Scalar %2050, %float8.192000e03_2255 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2066 = torch.aten.bitwise_not %2065 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2067 = 
torch.aten.mul.Tensor %2064, %2066 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2068 = torch.aten.where.self %2067, %2062, %2053 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2069 = torch.prim.ListConstruct %2068, %2068 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_2256 = torch.constant.int -1 + %2070 = torch.aten.cat %2069, %int-1_2256 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_2257 = torch.constant.int 6 + %2071 = torch.prims.convert_element_type %2070, %int6_2257 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_2258 = torch.constant.int 1 + %2072 = torch.aten.unsqueeze %2042, %int1_2258 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_2259 = torch.constant.int 6 + %2073 = torch.prims.convert_element_type %2072, %int6_2259 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_2260 = torch.constant.int 0 + %2074 = torch.aten.unsqueeze %2071, %int0_2260 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> %int6_2261 = torch.constant.int 6 - %2071 = torch.prims.convert_element_type %2070, %int6_2261 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2071, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_2262 = torch.constant.int 2 - %2072 = torch.aten.pow.Tensor_Scalar %2071, %int2_2262 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2072, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_2263 = torch.constant.int -1 - %2073 = torch.prim.ListConstruct %int-1_2263 : (!torch.int) -> !torch.list - %true_2264 = torch.constant.bool true - %none_2265 = torch.constant.none - %2074 = torch.aten.mean.dim %2072, %2073, %true_2264, %none_2265 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2074, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_2266 = torch.constant.float 9.9999997473787516E-6 + %2075 = torch.prims.convert_element_type %2074, %int6_2261 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %2076 = torch.aten.mul.Tensor %2073, %2075 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %2077 = torch.aten.cos %2076 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_2262 = torch.constant.int 5 + %2078 = torch.prims.convert_element_type %2077, %int5_2262 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %2079 = torch.aten.sin %2076 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_2263 = torch.constant.int 5 + %2080 = torch.prims.convert_element_type %2079, %int5_2263 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_2264 = torch.constant.int 0 + %int0_2265 = torch.constant.int 0 + %int1_2266 = torch.constant.int 1 + %2081 = torch.aten.slice.Tensor %2078, %int0_2264, %int0_2265, %298, %int1_2266 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2081, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : 
!torch.vtensor<[?,128],f16> %int1_2267 = torch.constant.int 1 - %2075 = torch.aten.add.Scalar %2074, %float9.999990e-06_2266, %int1_2267 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2075, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2076 = torch.aten.rsqrt %2075 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2076, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2077 = torch.aten.mul.Tensor %2071, %2076 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2077, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2268 = torch.constant.int 5 - %2078 = torch.prims.convert_element_type %2077, %int5_2268 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2078, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %2079 = torch.aten.mul.Tensor %82, %2078 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2079, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2269 = torch.constant.int 5 - %2080 = torch.prims.convert_element_type %2079, %int5_2269 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2080, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_2270 = torch.constant.int -2 - %int-1_2271 = torch.constant.int -1 - %2081 = torch.aten.transpose.int %83, %int-2_2270, %int-1_2271 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2272 = torch.constant.int 4 - %2082 = torch.aten.mul.int %int4_2272, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2273 = torch.constant.int 4096 - %2083 = torch.prim.ListConstruct %2082, %int4096_2273 : (!torch.int, !torch.int) -> !torch.list - %2084 = torch.aten.view %2080, %2083 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2084, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2085 = torch.aten.mm %2084, %2081 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2085, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_2274 = torch.constant.int 4 - %int4096_2275 = torch.constant.int 4096 - %2086 = torch.prim.ListConstruct %int4_2274, %306, %int4096_2275 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2087 = torch.aten.view %2085, %2086 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2087, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_2276 = torch.constant.int -2 - %int-1_2277 = torch.constant.int -1 - %2088 = torch.aten.transpose.int %84, %int-2_2276, %int-1_2277 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2278 = torch.constant.int 4 - %2089 = torch.aten.mul.int %int4_2278, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2279 = torch.constant.int 4096 - %2090 = torch.prim.ListConstruct %2089, %int4096_2279 : (!torch.int, !torch.int) -> !torch.list - %2091 = 
torch.aten.view %2080, %2090 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2091, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2092 = torch.aten.mm %2091, %2088 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %2092, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_2280 = torch.constant.int 4 - %int1024_2281 = torch.constant.int 1024 - %2093 = torch.prim.ListConstruct %int4_2280, %306, %int1024_2281 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2094 = torch.aten.view %2092, %2093 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %2094, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_2282 = torch.constant.int -2 - %int-1_2283 = torch.constant.int -1 - %2095 = torch.aten.transpose.int %85, %int-2_2282, %int-1_2283 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2284 = torch.constant.int 4 - %2096 = torch.aten.mul.int %int4_2284, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2285 = torch.constant.int 4096 - %2097 = torch.prim.ListConstruct %2096, %int4096_2285 : (!torch.int, !torch.int) -> !torch.list - %2098 = torch.aten.view %2080, %2097 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2098, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2099 = torch.aten.mm %2098, %2095 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %2099, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_2286 = torch.constant.int 4 - %int1024_2287 = torch.constant.int 1024 - %2100 = torch.prim.ListConstruct %int4_2286, %306, %int1024_2287 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2101 = torch.aten.view %2099, %2100 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %2101, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int0_2268 = torch.constant.int 0 + %int9223372036854775807_2269 = torch.constant.int 9223372036854775807 + %int1_2270 = torch.constant.int 1 + %2082 = torch.aten.slice.Tensor %2081, %int1_2267, %int0_2268, %int9223372036854775807_2269, %int1_2270 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2082, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_2271 = torch.constant.int 0 + %int0_2272 = torch.constant.int 0 + %int1_2273 = torch.constant.int 1 + %2083 = torch.aten.slice.Tensor %2080, %int0_2271, %int0_2272, %298, %int1_2273 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2083, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_2274 = torch.constant.int 1 + %int0_2275 = torch.constant.int 0 + %int9223372036854775807_2276 = torch.constant.int 9223372036854775807 + %int1_2277 = torch.constant.int 1 + %2084 = torch.aten.slice.Tensor %2083, %int1_2274, %int0_2275, %int9223372036854775807_2276, %int1_2277 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> 
!torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2084, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_2278 = torch.constant.int 0 + %2085 = torch.aten.unsqueeze %2082, %int0_2278 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2085, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_2279 = torch.constant.int 1 + %int0_2280 = torch.constant.int 0 + %int9223372036854775807_2281 = torch.constant.int 9223372036854775807 + %int1_2282 = torch.constant.int 1 + %2086 = torch.aten.slice.Tensor %2085, %int1_2279, %int0_2280, %int9223372036854775807_2281, %int1_2282 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2086, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_2283 = torch.constant.int 2 + %2087 = torch.aten.unsqueeze %2086, %int2_2283 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2087, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_2284 = torch.constant.int 3 + %int0_2285 = torch.constant.int 0 + %int9223372036854775807_2286 = torch.constant.int 9223372036854775807 + %int1_2287 = torch.constant.int 1 + %2088 = torch.aten.slice.Tensor %2087, %int3_2284, %int0_2285, %int9223372036854775807_2286, %int1_2287 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2088, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> %int4_2288 = torch.constant.int 4 - %int32_2289 = torch.constant.int 32 - %int128_2290 = torch.constant.int 128 - %2102 = torch.prim.ListConstruct %int4_2288, %306, %int32_2289, %int128_2290 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2103 = torch.aten.view %2087, %2102 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2103, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_2291 = torch.constant.int 4 - %int8_2292 = torch.constant.int 8 - %int128_2293 = torch.constant.int 128 - %2104 = torch.prim.ListConstruct %int4_2291, %306, %int8_2292, %int128_2293 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2105 = torch.aten.view %2094, %2104 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2105, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_2294 = torch.constant.int 4 - %int8_2295 = torch.constant.int 8 - %int128_2296 = torch.constant.int 128 - %2106 = torch.prim.ListConstruct %int4_2294, %306, %int8_2295, %int128_2296 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2107 = torch.aten.view %2101, %2106 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2107, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_2297 = torch.constant.int 131072 - %none_2298 = torch.constant.none - %none_2299 = torch.constant.none - %cpu_2300 = torch.constant.device "cpu" - %false_2301 = torch.constant.bool false - %2108 = torch.aten.arange %int131072_2297, %none_2298, %none_2299, %cpu_2300, %false_2301 : !torch.int, !torch.none, 
!torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_2302 = torch.constant.int 0 - %int128_2303 = torch.constant.int 128 - %none_2304 = torch.constant.none - %none_2305 = torch.constant.none - %cpu_2306 = torch.constant.device "cpu" - %false_2307 = torch.constant.bool false - %2109 = torch.aten.arange.start %int0_2302, %int128_2303, %none_2304, %none_2305, %cpu_2306, %false_2307 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_2308 = torch.constant.int 2 - %2110 = torch.aten.floor_divide.Scalar %2109, %int2_2308 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_2309 = torch.constant.int 6 - %2111 = torch.prims.convert_element_type %2110, %int6_2309 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_2310 = torch.constant.int 128 - %2112 = torch.aten.div.Scalar %2111, %int128_2310 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_2311 = torch.constant.float 2.000000e+00 - %2113 = torch.aten.mul.Scalar %2112, %float2.000000e00_2311 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_2312 = torch.constant.float 5.000000e+05 - %2114 = torch.aten.pow.Scalar %float5.000000e05_2312, %2113 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %2115 = torch.aten.reciprocal %2114 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_2313 = torch.constant.float 1.000000e+00 - %2116 = torch.aten.mul.Scalar %2115, %float1.000000e00_2313 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_2314 = torch.constant.int 1 - %2117 = torch.aten.unsqueeze %2108, %int1_2314 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_2315 = torch.constant.int 0 - %2118 = torch.aten.unsqueeze %2116, %int0_2315 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %2119 = torch.aten.mul.Tensor %2117, %2118 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_2316 = torch.constant.int 1 - %2120 = torch.aten.size.int %2087, %int1_2316 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_2317 = torch.constant.int 0 - %2121 = torch.aten.add.int %int0_2317, %2120 : !torch.int, !torch.int -> !torch.int - %int0_2318 = torch.constant.int 0 - %int0_2319 = torch.constant.int 0 - %int1_2320 = torch.constant.int 1 - %2122 = torch.aten.slice.Tensor %2119, %int0_2318, %int0_2319, %2121, %int1_2320 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2122, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %int1_2289 = torch.constant.int 1 + %int1_2290 = torch.constant.int 1 + %int1_2291 = torch.constant.int 1 + %2089 = torch.prim.ListConstruct %int4_2288, %int1_2289, %int1_2290, %int1_2291 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2090 = torch.aten.repeat %2088, %2089 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2090, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_2292 = torch.constant.int 0 + %2091 = torch.aten.unsqueeze %2084, %int0_2292 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2091, [%294], affine_map<()[s0] 
-> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_2293 = torch.constant.int 1 + %int0_2294 = torch.constant.int 0 + %int9223372036854775807_2295 = torch.constant.int 9223372036854775807 + %int1_2296 = torch.constant.int 1 + %2092 = torch.aten.slice.Tensor %2091, %int1_2293, %int0_2294, %int9223372036854775807_2295, %int1_2296 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2092, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_2297 = torch.constant.int 2 + %2093 = torch.aten.unsqueeze %2092, %int2_2297 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2093, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_2298 = torch.constant.int 3 + %int0_2299 = torch.constant.int 0 + %int9223372036854775807_2300 = torch.constant.int 9223372036854775807 + %int1_2301 = torch.constant.int 1 + %2094 = torch.aten.slice.Tensor %2093, %int3_2298, %int0_2299, %int9223372036854775807_2300, %int1_2301 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2094, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_2302 = torch.constant.int 4 + %int1_2303 = torch.constant.int 1 + %int1_2304 = torch.constant.int 1 + %int1_2305 = torch.constant.int 1 + %2095 = torch.prim.ListConstruct %int4_2302, %int1_2303, %int1_2304, %int1_2305 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2096 = torch.aten.repeat %2094, %2095 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2096, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %2097 = torch.aten.mul.Tensor %1976, %2090 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2097, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_2306 = torch.constant.int 3 + %int0_2307 = torch.constant.int 0 + %int64_2308 = torch.constant.int 64 + %int1_2309 = torch.constant.int 1 + %2098 = torch.aten.slice.Tensor %1976, %int3_2306, %int0_2307, %int64_2308, %int1_2309 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %2098, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_2310 = torch.constant.int 3 + %int64_2311 = torch.constant.int 64 + %int9223372036854775807_2312 = torch.constant.int 9223372036854775807 + %int1_2313 = torch.constant.int 1 + %2099 = torch.aten.slice.Tensor %1976, %int3_2310, %int64_2311, %int9223372036854775807_2312, %int1_2313 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %2099, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %2100 = torch.aten.neg %2099 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %2100, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %2101 = torch.prim.ListConstruct %2100, %2098 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_2314 = torch.constant.int -1 + 
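// ----------------------------------------------------------------------------
// Annotation: inlined rotary position embedding. The (+) ops above build the
// cos/sin tables: a [131072] position vector (%2042) is multiplied by a [64]
// inverse-frequency vector (where-selected between two f32 variants, %2068)
// duplicated to width 128 via cat([f, f], -1), giving [131072,128] angles
// (%2076); cos/sin are taken, cast to f16, sliced to the active sequence
// length, and broadcast to [4, seq, 1, 128] (%2090 cos, %2096 sin). The
// slice/neg/cat sequence around this point applies the rotate-half identity
// to the 8-head K projection %1976. A rough PyTorch-style sketch of the same
// computation (variable names illustrative, not from the IR):
//   x1, x2 = k[..., :64], k[..., 64:]
//   k_rope = k * cos + torch.cat([-x2, x1], dim=-1) * sin   # -> %2104
// This replaces the removed (-) path, which upcast to f32 and called the
// @sharktank_rotary_embedding_* kernels via util.call.
// ----------------------------------------------------------------------------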
%2102 = torch.aten.cat %2101, %int-1_2314 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2102, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %2103 = torch.aten.mul.Tensor %2102, %2096 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2103, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_2315 = torch.constant.int 1 + %2104 = torch.aten.add.Tensor %2097, %2103, %int1_2315 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2104, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_2316 = torch.constant.int 32 + %2105 = torch.aten.mul.Scalar %arg2, %int32_2316 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2105, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int6_2317 = torch.constant.int 6 + %int1_2318 = torch.constant.int 1 + %2106 = torch.aten.add.Scalar %2105, %int6_2317, %int1_2318 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2106, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_2319 = torch.constant.int 2 + %2107 = torch.aten.mul.Scalar %2106, %int2_2319 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2107, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_2320 = torch.constant.int 0 %int1_2321 = torch.constant.int 1 - %int0_2322 = torch.constant.int 0 - %int9223372036854775807_2323 = torch.constant.int 9223372036854775807 - %int1_2324 = torch.constant.int 1 - %2123 = torch.aten.slice.Tensor %2122, %int1_2321, %int0_2322, %int9223372036854775807_2323, %int1_2324 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2123, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2325 = torch.constant.int 1 - %int0_2326 = torch.constant.int 0 - %int9223372036854775807_2327 = torch.constant.int 9223372036854775807 - %int1_2328 = torch.constant.int 1 - %2124 = torch.aten.slice.Tensor %2123, %int1_2325, %int0_2326, %int9223372036854775807_2327, %int1_2328 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2124, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_2329 = torch.constant.int 0 - %2125 = torch.aten.unsqueeze %2124, %int0_2329 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2125, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_2330 = torch.constant.int 1 - %int0_2331 = torch.constant.int 0 - %int9223372036854775807_2332 = torch.constant.int 9223372036854775807 - %int1_2333 = torch.constant.int 1 - %2126 = torch.aten.slice.Tensor %2125, %int1_2330, %int0_2331, %int9223372036854775807_2332, %int1_2333 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2126, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_2334 = torch.constant.int 2 - %int0_2335 = 
torch.constant.int 0 - %int9223372036854775807_2336 = torch.constant.int 9223372036854775807 - %int1_2337 = torch.constant.int 1 - %2127 = torch.aten.slice.Tensor %2126, %int2_2334, %int0_2335, %int9223372036854775807_2336, %int1_2337 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2127, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_2338 = torch.constant.int 4 - %int1_2339 = torch.constant.int 1 - %int1_2340 = torch.constant.int 1 - %2128 = torch.prim.ListConstruct %int4_2338, %int1_2339, %int1_2340 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2129 = torch.aten.repeat %2127, %2128 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %2129, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_2341 = torch.constant.int 6 - %2130 = torch.prims.convert_element_type %2103, %int6_2341 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %2130, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %2131 = torch_c.to_builtin_tensor %2130 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %2132 = torch_c.to_builtin_tensor %2129 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %2133 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%2131, %2132) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %2134 = torch_c.from_builtin_tensor %2133 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %2134, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_2342 = torch.constant.int 5 - %2135 = torch.prims.convert_element_type %2134, %int5_2342 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2135, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_2343 = torch.constant.int 131072 - %none_2344 = torch.constant.none - %none_2345 = torch.constant.none - %cpu_2346 = torch.constant.device "cpu" - %false_2347 = torch.constant.bool false - %2136 = torch.aten.arange %int131072_2343, %none_2344, %none_2345, %cpu_2346, %false_2347 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_2348 = torch.constant.int 0 - %int128_2349 = torch.constant.int 128 - %none_2350 = torch.constant.none - %none_2351 = torch.constant.none - %cpu_2352 = torch.constant.device "cpu" - %false_2353 = torch.constant.bool false - %2137 = torch.aten.arange.start %int0_2348, %int128_2349, %none_2350, %none_2351, %cpu_2352, %false_2353 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_2354 = torch.constant.int 2 - %2138 = torch.aten.floor_divide.Scalar %2137, %int2_2354 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_2355 = torch.constant.int 6 - %2139 = torch.prims.convert_element_type %2138, %int6_2355 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_2356 = torch.constant.int 128 - %2140 = torch.aten.div.Scalar %2139, %int128_2356 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_2357 = torch.constant.float 2.000000e+00 - %2141 = torch.aten.mul.Scalar %2140, 
%float2.000000e00_2357 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_2358 = torch.constant.float 5.000000e+05 - %2142 = torch.aten.pow.Scalar %float5.000000e05_2358, %2141 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %2143 = torch.aten.reciprocal %2142 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_2359 = torch.constant.float 1.000000e+00 - %2144 = torch.aten.mul.Scalar %2143, %float1.000000e00_2359 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %2108 = torch.aten.add.Scalar %2107, %int0_2320, %int1_2321 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2108, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %2109 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %2110 = torch.aten.view %2108, %2109 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %2110, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_2322 = torch.constant.int 4 + %int32_2323 = torch.constant.int 32 + %int8_2324 = torch.constant.int 8 + %int128_2325 = torch.constant.int 128 + %2111 = torch.prim.ListConstruct %int4_2322, %296, %int32_2323, %int8_2324, %int128_2325 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2112 = torch.aten.view %2104, %2111 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2112, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_2326 = torch.constant.int 32 + %int8_2327 = torch.constant.int 8 + %int128_2328 = torch.constant.int 128 + %2113 = torch.prim.ListConstruct %504, %int32_2326, %int8_2327, %int128_2328 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2114 = torch.aten.view %2112, %2113 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %2114, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_2329 = torch.constant.int 1 + %int2_2330 = torch.constant.int 2 + %2115 = torch.aten.transpose.int %2114, %int1_2329, %int2_2330 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2115, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_2331 = torch.constant.int 5 + %2116 = torch.prims.convert_element_type %2115, %int5_2331 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2116, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_2332 = torch.constant.int 32 + %int2_2333 = torch.constant.int 2 + %int8_2334 = torch.constant.int 8 + %int32_2335 = torch.constant.int 32 + %int128_2336 = torch.constant.int 128 + %2117 = torch.prim.ListConstruct %297, %int32_2332, %int2_2333, %int8_2334, %int32_2335, %int128_2336 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2118 = torch.aten.view %1880, %2117 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2118, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_2337 = torch.constant.int 8 + %int32_2338 = torch.constant.int 32 + %int128_2339 
= torch.constant.int 128 + %2119 = torch.prim.ListConstruct %497, %int8_2337, %int32_2338, %int128_2339 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2120 = torch.aten.view %2118, %2119 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2120, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %2121 = torch.prim.ListConstruct %2110 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_2340 = torch.constant.bool false + %2122 = torch.aten.index_put %2120, %2121, %2116, %false_2340 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2122, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_2341 = torch.constant.int 32 + %int2_2342 = torch.constant.int 2 + %int8_2343 = torch.constant.int 8 + %int32_2344 = torch.constant.int 32 + %int128_2345 = torch.constant.int 128 + %2123 = torch.prim.ListConstruct %297, %int32_2341, %int2_2342, %int8_2343, %int32_2344, %int128_2345 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2124 = torch.aten.view %2122, %2123 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2124, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_2346 = torch.constant.int 2097152 + %2125 = torch.prim.ListConstruct %297, %int2097152_2346 : (!torch.int, !torch.int) -> !torch.list + %2126 = torch.aten.view %2124, %2125 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2126, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_2347 = torch.constant.int 32 + %int2_2348 = torch.constant.int 2 + %int8_2349 = torch.constant.int 8 + %int32_2350 = torch.constant.int 32 + %int128_2351 = torch.constant.int 128 + %2127 = torch.prim.ListConstruct %297, %int32_2347, %int2_2348, %int8_2349, %int32_2350, %int128_2351 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2128 = torch.aten.view %2126, %2127 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2128, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_2352 = torch.constant.int 8 + %int32_2353 = torch.constant.int 32 + %int128_2354 = torch.constant.int 128 + %2129 = torch.prim.ListConstruct %497, %int8_2352, %int32_2353, %int128_2354 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2130 = torch.aten.view %2128, %2129 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2130, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_2355 = torch.constant.int 32 + %2131 = torch.aten.mul.Scalar %arg2, %int32_2355 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2131, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int6_2356 = torch.constant.int 6 + %int1_2357 = torch.constant.int 1 + %2132 = torch.aten.add.Scalar %2131, %int6_2356, %int1_2357 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2132, 
[%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_2358 = torch.constant.int 2 + %2133 = torch.aten.mul.Scalar %2132, %int2_2358 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2133, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_2359 = torch.constant.int 1 %int1_2360 = torch.constant.int 1 - %2145 = torch.aten.unsqueeze %2136, %int1_2360 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_2361 = torch.constant.int 0 - %2146 = torch.aten.unsqueeze %2144, %int0_2361 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %2147 = torch.aten.mul.Tensor %2145, %2146 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_2362 = torch.constant.int 1 - %2148 = torch.aten.size.int %2094, %int1_2362 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_2363 = torch.constant.int 0 - %2149 = torch.aten.add.int %int0_2363, %2148 : !torch.int, !torch.int -> !torch.int - %int0_2364 = torch.constant.int 0 - %int0_2365 = torch.constant.int 0 - %int1_2366 = torch.constant.int 1 - %2150 = torch.aten.slice.Tensor %2147, %int0_2364, %int0_2365, %2149, %int1_2366 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2150, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2367 = torch.constant.int 1 - %int0_2368 = torch.constant.int 0 - %int9223372036854775807_2369 = torch.constant.int 9223372036854775807 - %int1_2370 = torch.constant.int 1 - %2151 = torch.aten.slice.Tensor %2150, %int1_2367, %int0_2368, %int9223372036854775807_2369, %int1_2370 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2151, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2371 = torch.constant.int 1 - %int0_2372 = torch.constant.int 0 - %int9223372036854775807_2373 = torch.constant.int 9223372036854775807 - %int1_2374 = torch.constant.int 1 - %2152 = torch.aten.slice.Tensor %2151, %int1_2371, %int0_2372, %int9223372036854775807_2373, %int1_2374 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2152, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_2375 = torch.constant.int 0 - %2153 = torch.aten.unsqueeze %2152, %int0_2375 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2153, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_2376 = torch.constant.int 1 - %int0_2377 = torch.constant.int 0 - %int9223372036854775807_2378 = torch.constant.int 9223372036854775807 - %int1_2379 = torch.constant.int 1 - %2154 = torch.aten.slice.Tensor %2153, %int1_2376, %int0_2377, %int9223372036854775807_2378, %int1_2379 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2154, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_2380 = torch.constant.int 2 - %int0_2381 = torch.constant.int 0 - %int9223372036854775807_2382 = torch.constant.int 9223372036854775807 - %int1_2383 = torch.constant.int 1 - %2155 = torch.aten.slice.Tensor %2154, %int2_2380, %int0_2381, 
%int9223372036854775807_2382, %int1_2383 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2155, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_2384 = torch.constant.int 4 - %int1_2385 = torch.constant.int 1 - %int1_2386 = torch.constant.int 1 - %2156 = torch.prim.ListConstruct %int4_2384, %int1_2385, %int1_2386 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2157 = torch.aten.repeat %2155, %2156 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %2157, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_2387 = torch.constant.int 6 - %2158 = torch.prims.convert_element_type %2105, %int6_2387 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %2158, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %2159 = torch_c.to_builtin_tensor %2158 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %2160 = torch_c.to_builtin_tensor %2157 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %2161 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%2159, %2160) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %2162 = torch_c.from_builtin_tensor %2161 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %2162, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_2388 = torch.constant.int 5 - %2163 = torch.prims.convert_element_type %2162, %int5_2388 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2163, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_2389 = torch.constant.int 64 - %2164 = torch.aten.mul.Scalar %arg2, %int64_2389 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2164, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int18 = torch.constant.int 18 - %int1_2390 = torch.constant.int 1 - %2165 = torch.aten.add.Scalar %2164, %int18, %int1_2390 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2165, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %2134 = torch.aten.add.Scalar %2133, %int1_2359, %int1_2360 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2134, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %2135 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %2136 = torch.aten.view %2134, %2135 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %2136, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_2361 = torch.constant.int 4 + %int32_2362 = torch.constant.int 32 + %int8_2363 = torch.constant.int 8 + %int128_2364 = torch.constant.int 128 + %2137 = torch.prim.ListConstruct %int4_2361, %296, %int32_2362, %int8_2363, %int128_2364 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2138 = torch.aten.view %1978, %2137 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2138, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : 
!torch.vtensor<[4,?,32,8,128],f16> + %int32_2365 = torch.constant.int 32 + %int8_2366 = torch.constant.int 8 + %int128_2367 = torch.constant.int 128 + %2139 = torch.prim.ListConstruct %504, %int32_2365, %int8_2366, %int128_2367 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2140 = torch.aten.view %2138, %2139 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %2140, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_2368 = torch.constant.int 1 + %int2_2369 = torch.constant.int 2 + %2141 = torch.aten.transpose.int %2140, %int1_2368, %int2_2369 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2141, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_2370 = torch.constant.int 5 + %2142 = torch.prims.convert_element_type %2141, %int5_2370 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2142, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %2143 = torch.prim.ListConstruct %2136 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_2371 = torch.constant.bool false + %2144 = torch.aten.index_put %2130, %2143, %2142, %false_2371 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2144, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_2372 = torch.constant.int 32 + %int2_2373 = torch.constant.int 2 + %int8_2374 = torch.constant.int 8 + %int32_2375 = torch.constant.int 32 + %int128_2376 = torch.constant.int 128 + %2145 = torch.prim.ListConstruct %297, %int32_2372, %int2_2373, %int8_2374, %int32_2375, %int128_2376 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2146 = torch.aten.view %2144, %2145 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2146, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_2377 = torch.constant.int 2097152 + %2147 = torch.prim.ListConstruct %297, %int2097152_2377 : (!torch.int, !torch.int) -> !torch.list + %2148 = torch.aten.view %2146, %2147 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2148, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_2378 = torch.constant.int -2 + %2149 = torch.aten.unsqueeze %2104, %int-2_2378 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2149, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_2379 = torch.constant.int 4 + %int8_2380 = torch.constant.int 8 + %int4_2381 = torch.constant.int 4 + %int128_2382 = torch.constant.int 128 + %2150 = torch.prim.ListConstruct %int4_2379, %298, %int8_2380, %int4_2381, %int128_2382 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2383 = torch.constant.bool false + %2151 = torch.aten.expand %2149, %2150, %false_2383 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2151, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 
4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2384 = torch.constant.int 0 + %2152 = torch.aten.clone %2151, %int0_2384 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2152, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_2385 = torch.constant.int 4 + %int32_2386 = torch.constant.int 32 + %int128_2387 = torch.constant.int 128 + %2153 = torch.prim.ListConstruct %int4_2385, %298, %int32_2386, %int128_2387 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2154 = torch.aten._unsafe_view %2152, %2153 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2154, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_2388 = torch.constant.int -2 + %2155 = torch.aten.unsqueeze %1978, %int-2_2388 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2155, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_2389 = torch.constant.int 4 + %int8_2390 = torch.constant.int 8 %int4_2391 = torch.constant.int 4 - %int32_2392 = torch.constant.int 32 - %int8_2393 = torch.constant.int 8 - %int128_2394 = torch.constant.int 128 - %2166 = torch.prim.ListConstruct %int4_2391, %398, %int32_2392, %int8_2393, %int128_2394 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2167 = torch.aten.view %2163, %2166 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2167, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int128_2392 = torch.constant.int 128 + %2156 = torch.prim.ListConstruct %int4_2389, %298, %int8_2390, %int4_2391, %int128_2392 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2393 = torch.constant.bool false + %2157 = torch.aten.expand %2155, %2156, %false_2393 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2157, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2394 = torch.constant.int 0 + %2158 = torch.aten.clone %2157, %int0_2394 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2158, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_2395 = torch.constant.int 4 - %2168 = torch.aten.mul.int %int4_2395, %398 : !torch.int, !torch.int -> !torch.int %int32_2396 = torch.constant.int 32 - %int8_2397 = torch.constant.int 8 - %int128_2398 = torch.constant.int 128 - %2169 = torch.prim.ListConstruct %2168, %int32_2396, %int8_2397, %int128_2398 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2170 = torch.aten.view %2167, %2169 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2170, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_2399 = torch.constant.int 4 - %2171 = torch.aten.mul.int %int4_2399, %398 : !torch.int, !torch.int -> !torch.int - %2172 = torch.prim.ListConstruct %2171 : (!torch.int) -> !torch.list - %2173 = torch.aten.view %2165, %2172 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape 
%2173, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_2400 = torch.constant.int 32 + %int128_2397 = torch.constant.int 128 + %2159 = torch.prim.ListConstruct %int4_2395, %298, %int32_2396, %int128_2397 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2160 = torch.aten._unsafe_view %2158, %2159 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2160, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_2398 = torch.constant.int 1 + %int2_2399 = torch.constant.int 2 + %2161 = torch.aten.transpose.int %2041, %int1_2398, %int2_2399 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2161, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2400 = torch.constant.int 1 %int2_2401 = torch.constant.int 2 - %int32_2402 = torch.constant.int 32 - %int8_2403 = torch.constant.int 8 - %int128_2404 = torch.constant.int 128 - %2174 = torch.prim.ListConstruct %389, %int32_2400, %int2_2401, %int32_2402, %int8_2403, %int128_2404 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2175 = torch.aten.view %2007, %2174 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2175, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2405 = torch.constant.int 32 - %2176 = torch.aten.mul.int %389, %int32_2405 : !torch.int, !torch.int -> !torch.int - %int2_2406 = torch.constant.int 2 - %2177 = torch.aten.mul.int %2176, %int2_2406 : !torch.int, !torch.int -> !torch.int - %int32_2407 = torch.constant.int 32 - %int8_2408 = torch.constant.int 8 - %int128_2409 = torch.constant.int 128 - %2178 = torch.prim.ListConstruct %2177, %int32_2407, %int8_2408, %int128_2409 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2179 = torch.aten.view %2175, %2178 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2179, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %2180 = torch.prim.ListConstruct %2173 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_2410 = torch.constant.bool false - %2181 = torch.aten.index_put %2179, %2180, %2170, %false_2410 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2181, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_2411 = torch.constant.int 32 - %int2_2412 = torch.constant.int 2 - %int32_2413 = torch.constant.int 32 - %int8_2414 = torch.constant.int 8 - %int128_2415 = torch.constant.int 128 - %2182 = torch.prim.ListConstruct %389, %int32_2411, %int2_2412, %int32_2413, %int8_2414, %int128_2415 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2183 = torch.aten.view %2181, %2182 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2183, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2416 = torch.constant.int 2097152 - %2184 = torch.prim.ListConstruct %389, %int2097152_2416 : (!torch.int, !torch.int) -> !torch.list - %2185 = torch.aten.view 
%2183, %2184 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2185, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_2417 = torch.constant.int 32 - %int2_2418 = torch.constant.int 2 - %int32_2419 = torch.constant.int 32 - %int8_2420 = torch.constant.int 8 - %int128_2421 = torch.constant.int 128 - %2186 = torch.prim.ListConstruct %389, %int32_2417, %int2_2418, %int32_2419, %int8_2420, %int128_2421 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2187 = torch.aten.view %2185, %2186 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2187, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2422 = torch.constant.int 32 - %int8_2423 = torch.constant.int 8 - %int128_2424 = torch.constant.int 128 - %2188 = torch.prim.ListConstruct %2177, %int32_2422, %int8_2423, %int128_2424 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2189 = torch.aten.view %2187, %2188 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2189, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_2425 = torch.constant.int 4 - %int32_2426 = torch.constant.int 32 - %int8_2427 = torch.constant.int 8 - %int128_2428 = torch.constant.int 128 - %2190 = torch.prim.ListConstruct %int4_2425, %398, %int32_2426, %int8_2427, %int128_2428 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2191 = torch.aten.view %2107, %2190 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2191, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_2429 = torch.constant.int 4 - %2192 = torch.aten.mul.int %int4_2429, %398 : !torch.int, !torch.int -> !torch.int - %int32_2430 = torch.constant.int 32 - %int8_2431 = torch.constant.int 8 - %int128_2432 = torch.constant.int 128 - %2193 = torch.prim.ListConstruct %2192, %int32_2430, %int8_2431, %int128_2432 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2194 = torch.aten.view %2191, %2193 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2194, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_2433 = torch.constant.int 1 - %int1_2434 = torch.constant.int 1 - %2195 = torch.aten.add.Scalar %2165, %int1_2433, %int1_2434 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2195, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_2435 = torch.constant.int 4 - %2196 = torch.aten.mul.int %int4_2435, %398 : !torch.int, !torch.int -> !torch.int - %2197 = torch.prim.ListConstruct %2196 : (!torch.int) -> !torch.list - %2198 = torch.aten.view %2195, %2197 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2198, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %2199 = torch.prim.ListConstruct %2198 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_2436 = torch.constant.bool false - %2200 = torch.aten.index_put %2189, %2199, %2194, %false_2436 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, 
!torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2200, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_2437 = torch.constant.int 32 - %int2_2438 = torch.constant.int 2 - %int32_2439 = torch.constant.int 32 - %int8_2440 = torch.constant.int 8 - %int128_2441 = torch.constant.int 128 - %2201 = torch.prim.ListConstruct %389, %int32_2437, %int2_2438, %int32_2439, %int8_2440, %int128_2441 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2202 = torch.aten.view %2200, %2201 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2202, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2442 = torch.constant.int 2097152 - %2203 = torch.prim.ListConstruct %389, %int2097152_2442 : (!torch.int, !torch.int) -> !torch.list - %2204 = torch.aten.view %2202, %2203 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2204, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_2443 = torch.constant.int -2 - %2205 = torch.aten.unsqueeze %2163, %int-2_2443 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2205, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_2444 = torch.constant.int 4 - %int8_2445 = torch.constant.int 8 - %int4_2446 = torch.constant.int 4 - %int128_2447 = torch.constant.int 128 - %2206 = torch.prim.ListConstruct %int4_2444, %2148, %int8_2445, %int4_2446, %int128_2447 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2448 = torch.constant.bool false - %2207 = torch.aten.expand %2205, %2206, %false_2448 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2207, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2449 = torch.constant.int 0 - %2208 = torch.aten.clone %2207, %int0_2449 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2208, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2450 = torch.constant.int 4 - %int32_2451 = torch.constant.int 32 - %int128_2452 = torch.constant.int 128 - %2209 = torch.prim.ListConstruct %int4_2450, %2148, %int32_2451, %int128_2452 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2210 = torch.aten._unsafe_view %2208, %2209 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2210, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_2453 = torch.constant.int -2 - %2211 = torch.aten.unsqueeze %2107, %int-2_2453 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2211, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_2454 = torch.constant.int 1 - %2212 = torch.aten.size.int %2101, %int1_2454 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_2455 = torch.constant.int 4 - %int8_2456 = torch.constant.int 8 - %int4_2457 = torch.constant.int 4 - %int128_2458 = torch.constant.int 128 - %2213 = 
torch.prim.ListConstruct %int4_2455, %2212, %int8_2456, %int4_2457, %int128_2458 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2459 = torch.constant.bool false - %2214 = torch.aten.expand %2211, %2213, %false_2459 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2214, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2460 = torch.constant.int 0 - %2215 = torch.aten.clone %2214, %int0_2460 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2215, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2461 = torch.constant.int 4 - %int32_2462 = torch.constant.int 32 - %int128_2463 = torch.constant.int 128 - %2216 = torch.prim.ListConstruct %int4_2461, %2212, %int32_2462, %int128_2463 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2217 = torch.aten._unsafe_view %2215, %2216 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2217, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_2464 = torch.constant.int 1 - %int2_2465 = torch.constant.int 2 - %2218 = torch.aten.transpose.int %2135, %int1_2464, %int2_2465 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2218, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_2466 = torch.constant.int 1 - %int2_2467 = torch.constant.int 2 - %2219 = torch.aten.transpose.int %2210, %int1_2466, %int2_2467 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2219, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_2468 = torch.constant.int 1 - %int2_2469 = torch.constant.int 2 - %2220 = torch.aten.transpose.int %2217, %int1_2468, %int2_2469 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2220, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_2470 = torch.constant.float 0.000000e+00 - %true_2471 = torch.constant.bool true - %none_2472 = torch.constant.none - %none_2473 = torch.constant.none - %2221:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2218, %2219, %2220, %float0.000000e00_2470, %true_2471, %none_2472, %none_2473) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %2221#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_2474 = torch.constant.int 1 - %int2_2475 = torch.constant.int 2 - %2222 = torch.aten.transpose.int %2221#0, %int1_2474, %int2_2475 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2222, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %2162 = torch.aten.transpose.int %2154, %int1_2400, %int2_2401 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + 
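// ----------------------------------------------------------------------------
// Annotation: paged KV-cache write and grouped-query attention. Above, the
// rotated K (%2104) and the V projection (%1978) are scattered into the cache
// %1880 with index_put; the flattened row index is (page_id * 32 + 6) * 2 for
// K and + 1 for V (%2110, %2136), consistent with the (+) side's
// [pages, 32, 2, 8, 32, 128] view, i.e. (page, layer, K/V, head, position,
// dim) -- 32*2*8*32*128 = 2097152, matching the flat [?,2097152] buffer. The
// (-) side instead used a [pages, 32, 2, 32, 8, 128] ordering with positions
// before heads. The unsqueeze/expand/clone/_unsafe_view sequence then repeats
// each of the 8 KV heads 4x to match the 32 query heads; roughly:
//   kv = kv[:, :, :, None, :].expand(4, seq, 8, 4, 128).reshape(4, seq, 32, 128)
// The flash-attention call that follows passes the explicit [4,1,?,?] mask
// %327 with is_causal = false, where the removed (-) variant passed no mask
// and relied on is_causal = true.
// ----------------------------------------------------------------------------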
torch.bind_symbolic_shape %2162, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2402 = torch.constant.int 1 + %int2_2403 = torch.constant.int 2 + %2163 = torch.aten.transpose.int %2160, %int1_2402, %int2_2403 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2163, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_2404 = torch.constant.float 0.000000e+00 + %false_2405 = torch.constant.bool false + %none_2406 = torch.constant.none + %2164:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2161, %2162, %2163, %float0.000000e00_2404, %false_2405, %327, %none_2406) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %2164#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2407 = torch.constant.int 1 + %int2_2408 = torch.constant.int 2 + %2165 = torch.aten.transpose.int %2164#0, %int1_2407, %int2_2408 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2165, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_2409 = torch.constant.int 4 + %int4096_2410 = torch.constant.int 4096 + %2166 = torch.prim.ListConstruct %int4_2409, %298, %int4096_2410 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2167 = torch.aten.view %2165, %2166 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2167, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_2411 = torch.constant.int -2 + %int-1_2412 = torch.constant.int -1 + %2168 = torch.aten.transpose.int %60, %int-2_2411, %int-1_2412 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2413 = torch.constant.int 5 + %2169 = torch.prims.convert_element_type %2168, %int5_2413 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_2414 = torch.constant.int 4096 + %2170 = torch.prim.ListConstruct %342, %int4096_2414 : (!torch.int, !torch.int) -> !torch.list + %2171 = torch.aten.view %2167, %2170 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2171, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2172 = torch.aten.mm %2171, %2169 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2172, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_2415 = torch.constant.int 4 + %int4096_2416 = torch.constant.int 4096 + %2173 = torch.prim.ListConstruct %int4_2415, %298, %int4096_2416 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2174 = torch.aten.view %2172, %2173 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2174, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_2417 = torch.constant.int 1 + %2175 = torch.aten.add.Tensor %1941, %2174, %int1_2417 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, 
!torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2175, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_2418 = torch.constant.int 6 + %2176 = torch.prims.convert_element_type %2175, %int6_2418 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2176, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_2419 = torch.constant.int 2 + %2177 = torch.aten.pow.Tensor_Scalar %2176, %int2_2419 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2177, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_2420 = torch.constant.int -1 + %2178 = torch.prim.ListConstruct %int-1_2420 : (!torch.int) -> !torch.list + %true_2421 = torch.constant.bool true + %none_2422 = torch.constant.none + %2179 = torch.aten.mean.dim %2177, %2178, %true_2421, %none_2422 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2179, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_2423 = torch.constant.float 9.9999997473787516E-6 + %int1_2424 = torch.constant.int 1 + %2180 = torch.aten.add.Scalar %2179, %float9.999990e-06_2423, %int1_2424 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2180, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2181 = torch.aten.rsqrt %2180 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2181, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2182 = torch.aten.mul.Tensor %2176, %2181 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2182, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_2425 = torch.constant.int 5 + %2183 = torch.prims.convert_element_type %2182, %int5_2425 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2183, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %2184 = torch.aten.mul.Tensor %61, %2183 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2184, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_2426 = torch.constant.int 5 + %2185 = torch.prims.convert_element_type %2184, %int5_2426 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2185, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_2427 = torch.constant.int -2 + %int-1_2428 = torch.constant.int -1 + %2186 = torch.aten.transpose.int %62, %int-2_2427, %int-1_2428 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2429 = torch.constant.int 5 + %2187 = torch.prims.convert_element_type %2186, %int5_2429 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_2430 = torch.constant.int 4096 + %2188 = torch.prim.ListConstruct %342, %int4096_2430 : (!torch.int, !torch.int) -> !torch.list + %2189 = torch.aten.view %2185, %2188 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> 
!torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2189, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2190 = torch.aten.mm %2189, %2187 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %2190, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_2431 = torch.constant.int 4 + %int14336_2432 = torch.constant.int 14336 + %2191 = torch.prim.ListConstruct %int4_2431, %298, %int14336_2432 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2192 = torch.aten.view %2190, %2191 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2192, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %2193 = torch.aten.silu %2192 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2193, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_2433 = torch.constant.int -2 + %int-1_2434 = torch.constant.int -1 + %2194 = torch.aten.transpose.int %63, %int-2_2433, %int-1_2434 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2435 = torch.constant.int 5 + %2195 = torch.prims.convert_element_type %2194, %int5_2435 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_2436 = torch.constant.int 4096 + %2196 = torch.prim.ListConstruct %342, %int4096_2436 : (!torch.int, !torch.int) -> !torch.list + %2197 = torch.aten.view %2185, %2196 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2197, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2198 = torch.aten.mm %2197, %2195 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %2198, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_2437 = torch.constant.int 4 + %int14336_2438 = torch.constant.int 14336 + %2199 = torch.prim.ListConstruct %int4_2437, %298, %int14336_2438 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2200 = torch.aten.view %2198, %2199 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2200, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %2201 = torch.aten.mul.Tensor %2193, %2200 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2201, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_2439 = torch.constant.int -2 + %int-1_2440 = torch.constant.int -1 + %2202 = torch.aten.transpose.int %64, %int-2_2439, %int-1_2440 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_2441 = torch.constant.int 5 + %2203 = torch.prims.convert_element_type %2202, %int5_2441 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_2442 = torch.constant.int 14336 + %2204 = torch.prim.ListConstruct %342, %int14336_2442 : (!torch.int, !torch.int) -> !torch.list + %2205 = torch.aten.view %2201, %2204 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %2205, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : 
!torch.vtensor<[?,14336],f16> + %2206 = torch.aten.mm %2205, %2203 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2206, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_2443 = torch.constant.int 4 + %int4096_2444 = torch.constant.int 4096 + %2207 = torch.prim.ListConstruct %int4_2443, %298, %int4096_2444 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2208 = torch.aten.view %2206, %2207 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2208, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_2445 = torch.constant.int 1 + %2209 = torch.aten.add.Tensor %2175, %2208, %int1_2445 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2209, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_2446 = torch.constant.int 6 + %2210 = torch.prims.convert_element_type %2209, %int6_2446 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2210, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_2447 = torch.constant.int 2 + %2211 = torch.aten.pow.Tensor_Scalar %2210, %int2_2447 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2211, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_2448 = torch.constant.int -1 + %2212 = torch.prim.ListConstruct %int-1_2448 : (!torch.int) -> !torch.list + %true_2449 = torch.constant.bool true + %none_2450 = torch.constant.none + %2213 = torch.aten.mean.dim %2211, %2212, %true_2449, %none_2450 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2213, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_2451 = torch.constant.float 9.9999997473787516E-6 + %int1_2452 = torch.constant.int 1 + %2214 = torch.aten.add.Scalar %2213, %float9.999990e-06_2451, %int1_2452 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2214, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2215 = torch.aten.rsqrt %2214 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2215, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2216 = torch.aten.mul.Tensor %2210, %2215 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2216, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_2453 = torch.constant.int 5 + %2217 = torch.prims.convert_element_type %2216, %int5_2453 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2217, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %2218 = torch.aten.mul.Tensor %65, %2217 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2218, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_2454 = torch.constant.int 5 + %2219 = torch.prims.convert_element_type 
%2218, %int5_2454 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2219, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_2455 = torch.constant.int -2 + %int-1_2456 = torch.constant.int -1 + %2220 = torch.aten.transpose.int %66, %int-2_2455, %int-1_2456 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2457 = torch.constant.int 5 + %2221 = torch.prims.convert_element_type %2220, %int5_2457 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_2458 = torch.constant.int 4096 + %2222 = torch.prim.ListConstruct %342, %int4096_2458 : (!torch.int, !torch.int) -> !torch.list + %2223 = torch.aten.view %2219, %2222 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2223, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2224 = torch.aten.mm %2223, %2221 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2224, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_2459 = torch.constant.int 4 + %int4096_2460 = torch.constant.int 4096 + %2225 = torch.prim.ListConstruct %int4_2459, %298, %int4096_2460 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2226 = torch.aten.view %2224, %2225 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2226, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_2461 = torch.constant.int -2 + %int-1_2462 = torch.constant.int -1 + %2227 = torch.aten.transpose.int %67, %int-2_2461, %int-1_2462 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2463 = torch.constant.int 5 + %2228 = torch.prims.convert_element_type %2227, %int5_2463 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_2464 = torch.constant.int 4096 + %2229 = torch.prim.ListConstruct %342, %int4096_2464 : (!torch.int, !torch.int) -> !torch.list + %2230 = torch.aten.view %2219, %2229 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2230, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2231 = torch.aten.mm %2230, %2228 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %2231, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_2465 = torch.constant.int 4 + %int1024_2466 = torch.constant.int 1024 + %2232 = torch.prim.ListConstruct %int4_2465, %298, %int1024_2466 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2233 = torch.aten.view %2231, %2232 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %2233, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_2467 = torch.constant.int -2 + %int-1_2468 = torch.constant.int -1 + %2234 = torch.aten.transpose.int %68, %int-2_2467, %int-1_2468 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2469 = torch.constant.int 5 + %2235 = torch.prims.convert_element_type %2234, %int5_2469 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + 
%int4096_2470 = torch.constant.int 4096 + %2236 = torch.prim.ListConstruct %342, %int4096_2470 : (!torch.int, !torch.int) -> !torch.list + %2237 = torch.aten.view %2219, %2236 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2237, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2238 = torch.aten.mm %2237, %2235 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %2238, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_2471 = torch.constant.int 4 + %int1024_2472 = torch.constant.int 1024 + %2239 = torch.prim.ListConstruct %int4_2471, %298, %int1024_2472 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2240 = torch.aten.view %2238, %2239 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %2240, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_2473 = torch.constant.int 4 + %int32_2474 = torch.constant.int 32 + %int128_2475 = torch.constant.int 128 + %2241 = torch.prim.ListConstruct %int4_2473, %298, %int32_2474, %int128_2475 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2242 = torch.aten.view %2226, %2241 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2242, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int4_2476 = torch.constant.int 4 - %int4096_2477 = torch.constant.int 4096 - %2223 = torch.prim.ListConstruct %int4_2476, %2120, %int4096_2477 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2224 = torch.aten.view %2222, %2223 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2224, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_2478 = torch.constant.int -2 - %int-1_2479 = torch.constant.int -1 - %2225 = torch.aten.transpose.int %86, %int-2_2478, %int-1_2479 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2480 = torch.constant.int 4 - %2226 = torch.aten.mul.int %int4_2480, %2120 : !torch.int, !torch.int -> !torch.int - %int4096_2481 = torch.constant.int 4096 - %2227 = torch.prim.ListConstruct %2226, %int4096_2481 : (!torch.int, !torch.int) -> !torch.list - %2228 = torch.aten.view %2224, %2227 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2228, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2229 = torch.aten.mm %2228, %2225 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2229, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_2482 = torch.constant.int 4 - %int4096_2483 = torch.constant.int 4096 - %2230 = torch.prim.ListConstruct %int4_2482, %2120, %int4096_2483 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2231 = torch.aten.view %2229, %2230 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2231, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_2484 = torch.constant.int 1 - %2232 = torch.aten.add.Tensor %2070, %2231, %int1_2484 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, 
!torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2232, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_2485 = torch.constant.int 6 - %2233 = torch.prims.convert_element_type %2232, %int6_2485 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2233, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_2486 = torch.constant.int 2 - %2234 = torch.aten.pow.Tensor_Scalar %2233, %int2_2486 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2234, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_2487 = torch.constant.int -1 - %2235 = torch.prim.ListConstruct %int-1_2487 : (!torch.int) -> !torch.list - %true_2488 = torch.constant.bool true - %none_2489 = torch.constant.none - %2236 = torch.aten.mean.dim %2234, %2235, %true_2488, %none_2489 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2236, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_2490 = torch.constant.float 9.9999997473787516E-6 - %int1_2491 = torch.constant.int 1 - %2237 = torch.aten.add.Scalar %2236, %float9.999990e-06_2490, %int1_2491 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2237, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2238 = torch.aten.rsqrt %2237 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2238, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2239 = torch.aten.mul.Tensor %2233, %2238 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2239, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2492 = torch.constant.int 5 - %2240 = torch.prims.convert_element_type %2239, %int5_2492 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2240, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %2241 = torch.aten.mul.Tensor %87, %2240 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2241, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2493 = torch.constant.int 5 - %2242 = torch.prims.convert_element_type %2241, %int5_2493 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2242, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_2494 = torch.constant.int -2 - %int-1_2495 = torch.constant.int -1 - %2243 = torch.aten.transpose.int %88, %int-2_2494, %int-1_2495 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_2496 = torch.constant.int 4 - %2244 = torch.aten.mul.int %int4_2496, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2497 = torch.constant.int 4096 - %2245 = torch.prim.ListConstruct %2244, %int4096_2497 : (!torch.int, !torch.int) -> !torch.list - %2246 = torch.aten.view %2242, %2245 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2246, 
[%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2247 = torch.aten.mm %2246, %2243 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2247, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_2498 = torch.constant.int 4 - %int14336_2499 = torch.constant.int 14336 - %2248 = torch.prim.ListConstruct %int4_2498, %306, %int14336_2499 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2249 = torch.aten.view %2247, %2248 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2249, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %2250 = torch.aten.silu %2249 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2250, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_2500 = torch.constant.int -2 - %int-1_2501 = torch.constant.int -1 - %2251 = torch.aten.transpose.int %89, %int-2_2500, %int-1_2501 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_2502 = torch.constant.int 4 - %2252 = torch.aten.mul.int %int4_2502, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2503 = torch.constant.int 4096 - %2253 = torch.prim.ListConstruct %2252, %int4096_2503 : (!torch.int, !torch.int) -> !torch.list - %2254 = torch.aten.view %2242, %2253 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2254, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2255 = torch.aten.mm %2254, %2251 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2255, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_2504 = torch.constant.int 4 - %int14336_2505 = torch.constant.int 14336 - %2256 = torch.prim.ListConstruct %int4_2504, %306, %int14336_2505 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2257 = torch.aten.view %2255, %2256 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2257, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %2258 = torch.aten.mul.Tensor %2250, %2257 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2258, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_2506 = torch.constant.int -2 - %int-1_2507 = torch.constant.int -1 - %2259 = torch.aten.transpose.int %90, %int-2_2506, %int-1_2507 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int8_2477 = torch.constant.int 8 + %int128_2478 = torch.constant.int 128 + %2243 = torch.prim.ListConstruct %int4_2476, %298, %int8_2477, %int128_2478 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2244 = torch.aten.view %2233, %2243 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2244, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_2479 = torch.constant.int 4 + %int8_2480 = torch.constant.int 8 + %int128_2481 = torch.constant.int 128 + %2245 = torch.prim.ListConstruct %int4_2479, %298, %int8_2480, 
%int128_2481 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2246 = torch.aten.view %2240, %2245 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2246, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_2482 = torch.constant.int 131072 + %none_2483 = torch.constant.none + %none_2484 = torch.constant.none + %cpu_2485 = torch.constant.device "cpu" + %false_2486 = torch.constant.bool false + %2247 = torch.aten.arange %int131072_2482, %none_2483, %none_2484, %cpu_2485, %false_2486 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_2487 = torch.constant.int 0 + %int128_2488 = torch.constant.int 128 + %int2_2489 = torch.constant.int 2 + %int4_2490 = torch.constant.int 4 + %none_2491 = torch.constant.none + %cpu_2492 = torch.constant.device "cpu" + %false_2493 = torch.constant.bool false + %2248 = torch.aten.arange.start_step %int0_2487, %int128_2488, %int2_2489, %int4_2490, %none_2491, %cpu_2492, %false_2493 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_2494 = torch.constant.int 6 + %2249 = torch.prims.convert_element_type %2248, %int6_2494 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_2495 = torch.constant.int 128 + %2250 = torch.aten.div.Scalar %2249, %int128_2495 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_2496 = torch.constant.float 5.000000e+05 + %2251 = torch.aten.pow.Scalar %float5.000000e05_2496, %2250 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2252 = torch.aten.reciprocal %2251 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_2497 = torch.constant.float 1.000000e+00 + %2253 = torch.aten.mul.Scalar %2252, %float1.000000e00_2497 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %2254 = torch.aten.reciprocal %2253 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_2498 = torch.constant.float 6.2831853071795862 + %2255 = torch.aten.mul.Scalar %2254, %float6.283190e00_2498 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_2499 = torch.constant.float 8.192000e+03 + %2256 = torch.aten.gt.Scalar %2255, %float8.192000e03_2499 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_2500 = torch.constant.int 8 + %2257 = torch.aten.div.Scalar %2253, %int8_2500 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2258 = torch.aten.where.self %2256, %2257, %2253 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2259 = torch.aten.reciprocal %2255 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_2501 = torch.constant.int 8192 + %2260 = torch.aten.mul.Scalar %2259, %int8192_2501 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_2502 = torch.constant.int 1 + %int1_2503 = torch.constant.int 1 + %2261 = torch.aten.sub.Scalar %2260, %int1_2502, %int1_2503 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_2504 = torch.constant.int 3 + %2262 = torch.aten.div.Scalar %2261, %int3_2504 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_2505 = torch.constant.int 1 + %int1_2506 = torch.constant.int 1 + %2263 = torch.aten.rsub.Scalar %2262, 
%int1_2505, %int1_2506 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %2264 = torch.aten.mul.Tensor %2263, %2258 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_2507 = torch.constant.int 8 + %2265 = torch.aten.div.Scalar %2264, %int8_2507 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2266 = torch.aten.mul.Tensor %2262, %2258 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> %int1_2508 = torch.constant.int 1 - %2260 = torch.aten.size.int %2249, %int1_2508 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_2509 = torch.constant.int 4 - %2261 = torch.aten.mul.int %int4_2509, %2260 : !torch.int, !torch.int -> !torch.int - %int14336_2510 = torch.constant.int 14336 - %2262 = torch.prim.ListConstruct %2261, %int14336_2510 : (!torch.int, !torch.int) -> !torch.list - %2263 = torch.aten.view %2258, %2262 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2263, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %2264 = torch.aten.mm %2263, %2259 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2264, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_2511 = torch.constant.int 4 - %int4096_2512 = torch.constant.int 4096 - %2265 = torch.prim.ListConstruct %int4_2511, %2260, %int4096_2512 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2266 = torch.aten.view %2264, %2265 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2266, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %2267 = torch.aten.add.Tensor %2265, %2266, %int1_2508 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_2509 = torch.constant.float 2.048000e+03 + %2268 = torch.aten.lt.Scalar %2255, %float2.048000e03_2509 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2269 = torch.aten.bitwise_not %2268 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_2510 = torch.constant.float 8.192000e+03 + %2270 = torch.aten.gt.Scalar %2255, %float8.192000e03_2510 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2271 = torch.aten.bitwise_not %2270 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2272 = torch.aten.mul.Tensor %2269, %2271 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2273 = torch.aten.where.self %2272, %2267, %2258 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2274 = torch.prim.ListConstruct %2273, %2273 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_2511 = torch.constant.int -1 + %2275 = torch.aten.cat %2274, %int-1_2511 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_2512 = torch.constant.int 6 + %2276 = torch.prims.convert_element_type %2275, %int6_2512 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> %int1_2513 = torch.constant.int 1 - %2267 = torch.aten.add.Tensor %2232, %2266, %int1_2513 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2267, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : 
!torch.vtensor<[4,?,4096],f16> + %2277 = torch.aten.unsqueeze %2247, %int1_2513 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> %int6_2514 = torch.constant.int 6 - %2268 = torch.prims.convert_element_type %2267, %int6_2514 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2268, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_2515 = torch.constant.int 2 - %2269 = torch.aten.pow.Tensor_Scalar %2268, %int2_2515 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2269, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_2516 = torch.constant.int -1 - %2270 = torch.prim.ListConstruct %int-1_2516 : (!torch.int) -> !torch.list - %true_2517 = torch.constant.bool true - %none_2518 = torch.constant.none - %2271 = torch.aten.mean.dim %2269, %2270, %true_2517, %none_2518 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2271, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_2519 = torch.constant.float 9.9999997473787516E-6 - %int1_2520 = torch.constant.int 1 - %2272 = torch.aten.add.Scalar %2271, %float9.999990e-06_2519, %int1_2520 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2272, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2273 = torch.aten.rsqrt %2272 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2273, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2274 = torch.aten.mul.Tensor %2268, %2273 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2274, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2521 = torch.constant.int 5 - %2275 = torch.prims.convert_element_type %2274, %int5_2521 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2275, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %2276 = torch.aten.mul.Tensor %91, %2275 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2276, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2522 = torch.constant.int 5 - %2277 = torch.prims.convert_element_type %2276, %int5_2522 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2277, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_2523 = torch.constant.int -2 - %int-1_2524 = torch.constant.int -1 - %2278 = torch.aten.transpose.int %92, %int-2_2523, %int-1_2524 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2525 = torch.constant.int 4 - %2279 = torch.aten.mul.int %int4_2525, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2526 = torch.constant.int 4096 - %2280 = torch.prim.ListConstruct %2279, %int4096_2526 : (!torch.int, !torch.int) -> !torch.list - %2281 = torch.aten.view %2277, %2280 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2281, [%292], 
affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2282 = torch.aten.mm %2281, %2278 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2282, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_2527 = torch.constant.int 4 - %int4096_2528 = torch.constant.int 4096 - %2283 = torch.prim.ListConstruct %int4_2527, %306, %int4096_2528 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2284 = torch.aten.view %2282, %2283 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2284, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_2529 = torch.constant.int -2 - %int-1_2530 = torch.constant.int -1 - %2285 = torch.aten.transpose.int %93, %int-2_2529, %int-1_2530 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2531 = torch.constant.int 4 - %2286 = torch.aten.mul.int %int4_2531, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2532 = torch.constant.int 4096 - %2287 = torch.prim.ListConstruct %2286, %int4096_2532 : (!torch.int, !torch.int) -> !torch.list - %2288 = torch.aten.view %2277, %2287 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2288, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2289 = torch.aten.mm %2288, %2285 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %2289, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_2533 = torch.constant.int 4 - %int1024_2534 = torch.constant.int 1024 - %2290 = torch.prim.ListConstruct %int4_2533, %306, %int1024_2534 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2291 = torch.aten.view %2289, %2290 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %2291, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_2535 = torch.constant.int -2 - %int-1_2536 = torch.constant.int -1 - %2292 = torch.aten.transpose.int %94, %int-2_2535, %int-1_2536 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2537 = torch.constant.int 4 - %2293 = torch.aten.mul.int %int4_2537, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2538 = torch.constant.int 4096 - %2294 = torch.prim.ListConstruct %2293, %int4096_2538 : (!torch.int, !torch.int) -> !torch.list - %2295 = torch.aten.view %2277, %2294 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2295, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2296 = torch.aten.mm %2295, %2292 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %2296, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_2539 = torch.constant.int 4 - %int1024_2540 = torch.constant.int 1024 - %2297 = torch.prim.ListConstruct %int4_2539, %306, %int1024_2540 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2298 = torch.aten.view %2296, %2297 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %2298, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : 
!torch.vtensor<[4,?,1024],f16> - %int4_2541 = torch.constant.int 4 - %int32_2542 = torch.constant.int 32 - %int128_2543 = torch.constant.int 128 - %2299 = torch.prim.ListConstruct %int4_2541, %306, %int32_2542, %int128_2543 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2300 = torch.aten.view %2284, %2299 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2300, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_2544 = torch.constant.int 4 - %int8_2545 = torch.constant.int 8 - %int128_2546 = torch.constant.int 128 - %2301 = torch.prim.ListConstruct %int4_2544, %306, %int8_2545, %int128_2546 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2302 = torch.aten.view %2291, %2301 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2302, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_2547 = torch.constant.int 4 - %int8_2548 = torch.constant.int 8 - %int128_2549 = torch.constant.int 128 - %2303 = torch.prim.ListConstruct %int4_2547, %306, %int8_2548, %int128_2549 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2304 = torch.aten.view %2298, %2303 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2304, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_2550 = torch.constant.int 131072 - %none_2551 = torch.constant.none - %none_2552 = torch.constant.none - %cpu_2553 = torch.constant.device "cpu" - %false_2554 = torch.constant.bool false - %2305 = torch.aten.arange %int131072_2550, %none_2551, %none_2552, %cpu_2553, %false_2554 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_2555 = torch.constant.int 0 - %int128_2556 = torch.constant.int 128 - %none_2557 = torch.constant.none - %none_2558 = torch.constant.none - %cpu_2559 = torch.constant.device "cpu" - %false_2560 = torch.constant.bool false - %2306 = torch.aten.arange.start %int0_2555, %int128_2556, %none_2557, %none_2558, %cpu_2559, %false_2560 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_2561 = torch.constant.int 2 - %2307 = torch.aten.floor_divide.Scalar %2306, %int2_2561 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_2562 = torch.constant.int 6 - %2308 = torch.prims.convert_element_type %2307, %int6_2562 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_2563 = torch.constant.int 128 - %2309 = torch.aten.div.Scalar %2308, %int128_2563 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_2564 = torch.constant.float 2.000000e+00 - %2310 = torch.aten.mul.Scalar %2309, %float2.000000e00_2564 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_2565 = torch.constant.float 5.000000e+05 - %2311 = torch.aten.pow.Scalar %float5.000000e05_2565, %2310 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %2312 = torch.aten.reciprocal %2311 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_2566 = torch.constant.float 1.000000e+00 - %2313 = torch.aten.mul.Scalar %2312, %float1.000000e00_2566 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_2567 = 
torch.constant.int 1 - %2314 = torch.aten.unsqueeze %2305, %int1_2567 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_2568 = torch.constant.int 0 - %2315 = torch.aten.unsqueeze %2313, %int0_2568 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %2316 = torch.aten.mul.Tensor %2314, %2315 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_2569 = torch.constant.int 1 - %2317 = torch.aten.size.int %2284, %int1_2569 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_2570 = torch.constant.int 0 - %2318 = torch.aten.add.int %int0_2570, %2317 : !torch.int, !torch.int -> !torch.int - %int0_2571 = torch.constant.int 0 - %int0_2572 = torch.constant.int 0 - %int1_2573 = torch.constant.int 1 - %2319 = torch.aten.slice.Tensor %2316, %int0_2571, %int0_2572, %2318, %int1_2573 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2319, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2574 = torch.constant.int 1 - %int0_2575 = torch.constant.int 0 - %int9223372036854775807_2576 = torch.constant.int 9223372036854775807 - %int1_2577 = torch.constant.int 1 - %2320 = torch.aten.slice.Tensor %2319, %int1_2574, %int0_2575, %int9223372036854775807_2576, %int1_2577 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2320, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2578 = torch.constant.int 1 - %int0_2579 = torch.constant.int 0 - %int9223372036854775807_2580 = torch.constant.int 9223372036854775807 - %int1_2581 = torch.constant.int 1 - %2321 = torch.aten.slice.Tensor %2320, %int1_2578, %int0_2579, %int9223372036854775807_2580, %int1_2581 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2321, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_2582 = torch.constant.int 0 - %2322 = torch.aten.unsqueeze %2321, %int0_2582 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2322, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_2583 = torch.constant.int 1 - %int0_2584 = torch.constant.int 0 - %int9223372036854775807_2585 = torch.constant.int 9223372036854775807 - %int1_2586 = torch.constant.int 1 - %2323 = torch.aten.slice.Tensor %2322, %int1_2583, %int0_2584, %int9223372036854775807_2585, %int1_2586 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2323, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_2587 = torch.constant.int 2 - %int0_2588 = torch.constant.int 0 - %int9223372036854775807_2589 = torch.constant.int 9223372036854775807 - %int1_2590 = torch.constant.int 1 - %2324 = torch.aten.slice.Tensor %2323, %int2_2587, %int0_2588, %int9223372036854775807_2589, %int1_2590 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2324, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_2591 = torch.constant.int 4 + %2278 = torch.prims.convert_element_type %2277, %int6_2514 : 
!torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_2515 = torch.constant.int 0 + %2279 = torch.aten.unsqueeze %2276, %int0_2515 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_2516 = torch.constant.int 6 + %2280 = torch.prims.convert_element_type %2279, %int6_2516 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %2281 = torch.aten.mul.Tensor %2278, %2280 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %2282 = torch.aten.cos %2281 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_2517 = torch.constant.int 5 + %2283 = torch.prims.convert_element_type %2282, %int5_2517 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %2284 = torch.aten.sin %2281 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_2518 = torch.constant.int 5 + %2285 = torch.prims.convert_element_type %2284, %int5_2518 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_2519 = torch.constant.int 0 + %int0_2520 = torch.constant.int 0 + %int1_2521 = torch.constant.int 1 + %2286 = torch.aten.slice.Tensor %2283, %int0_2519, %int0_2520, %298, %int1_2521 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2286, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_2522 = torch.constant.int 1 + %int0_2523 = torch.constant.int 0 + %int9223372036854775807_2524 = torch.constant.int 9223372036854775807 + %int1_2525 = torch.constant.int 1 + %2287 = torch.aten.slice.Tensor %2286, %int1_2522, %int0_2523, %int9223372036854775807_2524, %int1_2525 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2287, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_2526 = torch.constant.int 0 + %int0_2527 = torch.constant.int 0 + %int1_2528 = torch.constant.int 1 + %2288 = torch.aten.slice.Tensor %2285, %int0_2526, %int0_2527, %298, %int1_2528 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2288, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_2529 = torch.constant.int 1 + %int0_2530 = torch.constant.int 0 + %int9223372036854775807_2531 = torch.constant.int 9223372036854775807 + %int1_2532 = torch.constant.int 1 + %2289 = torch.aten.slice.Tensor %2288, %int1_2529, %int0_2530, %int9223372036854775807_2531, %int1_2532 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2289, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_2533 = torch.constant.int 0 + %2290 = torch.aten.unsqueeze %2287, %int0_2533 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2290, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_2534 = torch.constant.int 1 + %int0_2535 = torch.constant.int 0 + %int9223372036854775807_2536 = torch.constant.int 9223372036854775807 + %int1_2537 = torch.constant.int 1 + %2291 = torch.aten.slice.Tensor %2290, %int1_2534, %int0_2535, %int9223372036854775807_2536, %int1_2537 : !torch.vtensor<[1,?,128],f16>, !torch.int, 
!torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2291, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_2538 = torch.constant.int 2 + %2292 = torch.aten.unsqueeze %2291, %int2_2538 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2292, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_2539 = torch.constant.int 3 + %int0_2540 = torch.constant.int 0 + %int9223372036854775807_2541 = torch.constant.int 9223372036854775807 + %int1_2542 = torch.constant.int 1 + %2293 = torch.aten.slice.Tensor %2292, %int3_2539, %int0_2540, %int9223372036854775807_2541, %int1_2542 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2293, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_2543 = torch.constant.int 4 + %int1_2544 = torch.constant.int 1 + %int1_2545 = torch.constant.int 1 + %int1_2546 = torch.constant.int 1 + %2294 = torch.prim.ListConstruct %int4_2543, %int1_2544, %int1_2545, %int1_2546 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2295 = torch.aten.repeat %2293, %2294 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2295, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_2547 = torch.constant.int 0 + %2296 = torch.aten.unsqueeze %2289, %int0_2547 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2296, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_2548 = torch.constant.int 1 + %int0_2549 = torch.constant.int 0 + %int9223372036854775807_2550 = torch.constant.int 9223372036854775807 + %int1_2551 = torch.constant.int 1 + %2297 = torch.aten.slice.Tensor %2296, %int1_2548, %int0_2549, %int9223372036854775807_2550, %int1_2551 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2297, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_2552 = torch.constant.int 2 + %2298 = torch.aten.unsqueeze %2297, %int2_2552 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2298, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_2553 = torch.constant.int 3 + %int0_2554 = torch.constant.int 0 + %int9223372036854775807_2555 = torch.constant.int 9223372036854775807 + %int1_2556 = torch.constant.int 1 + %2299 = torch.aten.slice.Tensor %2298, %int3_2553, %int0_2554, %int9223372036854775807_2555, %int1_2556 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2299, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_2557 = torch.constant.int 4 + %int1_2558 = torch.constant.int 1 + %int1_2559 = torch.constant.int 1 + %int1_2560 = torch.constant.int 1 + %2300 = torch.prim.ListConstruct %int4_2557, %int1_2558, %int1_2559, %int1_2560 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2301 = torch.aten.repeat %2299, %2300 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + 
torch.bind_symbolic_shape %2301, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %2302 = torch.aten.mul.Tensor %2242, %2295 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2302, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_2561 = torch.constant.int 3 + %int0_2562 = torch.constant.int 0 + %int64_2563 = torch.constant.int 64 + %int1_2564 = torch.constant.int 1 + %2303 = torch.aten.slice.Tensor %2242, %int3_2561, %int0_2562, %int64_2563, %int1_2564 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %2303, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_2565 = torch.constant.int 3 + %int64_2566 = torch.constant.int 64 + %int9223372036854775807_2567 = torch.constant.int 9223372036854775807 + %int1_2568 = torch.constant.int 1 + %2304 = torch.aten.slice.Tensor %2242, %int3_2565, %int64_2566, %int9223372036854775807_2567, %int1_2568 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %2304, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %2305 = torch.aten.neg %2304 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %2305, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %2306 = torch.prim.ListConstruct %2305, %2303 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_2569 = torch.constant.int -1 + %2307 = torch.aten.cat %2306, %int-1_2569 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2307, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %2308 = torch.aten.mul.Tensor %2307, %2301 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2308, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_2570 = torch.constant.int 1 + %2309 = torch.aten.add.Tensor %2302, %2308, %int1_2570 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2309, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_2571 = torch.constant.int 131072 + %none_2572 = torch.constant.none + %none_2573 = torch.constant.none + %cpu_2574 = torch.constant.device "cpu" + %false_2575 = torch.constant.bool false + %2310 = torch.aten.arange %int131072_2571, %none_2572, %none_2573, %cpu_2574, %false_2575 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_2576 = torch.constant.int 0 + %int128_2577 = torch.constant.int 128 + %int2_2578 = torch.constant.int 2 + %int4_2579 = torch.constant.int 4 + %none_2580 = torch.constant.none + %cpu_2581 = torch.constant.device "cpu" + %false_2582 = torch.constant.bool false + %2311 = torch.aten.arange.start_step %int0_2576, %int128_2577, %int2_2578, %int4_2579, %none_2580, %cpu_2581, %false_2582 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_2583 = torch.constant.int 6 + 
%2312 = torch.prims.convert_element_type %2311, %int6_2583 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_2584 = torch.constant.int 128 + %2313 = torch.aten.div.Scalar %2312, %int128_2584 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_2585 = torch.constant.float 5.000000e+05 + %2314 = torch.aten.pow.Scalar %float5.000000e05_2585, %2313 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2315 = torch.aten.reciprocal %2314 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_2586 = torch.constant.float 1.000000e+00 + %2316 = torch.aten.mul.Scalar %2315, %float1.000000e00_2586 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %2317 = torch.aten.reciprocal %2316 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_2587 = torch.constant.float 6.2831853071795862 + %2318 = torch.aten.mul.Scalar %2317, %float6.283190e00_2587 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_2588 = torch.constant.float 8.192000e+03 + %2319 = torch.aten.gt.Scalar %2318, %float8.192000e03_2588 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_2589 = torch.constant.int 8 + %2320 = torch.aten.div.Scalar %2316, %int8_2589 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2321 = torch.aten.where.self %2319, %2320, %2316 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2322 = torch.aten.reciprocal %2318 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_2590 = torch.constant.int 8192 + %2323 = torch.aten.mul.Scalar %2322, %int8192_2590 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_2591 = torch.constant.int 1 %int1_2592 = torch.constant.int 1 - %int1_2593 = torch.constant.int 1 - %2325 = torch.prim.ListConstruct %int4_2591, %int1_2592, %int1_2593 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2326 = torch.aten.repeat %2324, %2325 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %2326, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_2594 = torch.constant.int 6 - %2327 = torch.prims.convert_element_type %2300, %int6_2594 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %2327, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %2328 = torch_c.to_builtin_tensor %2327 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %2329 = torch_c.to_builtin_tensor %2326 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %2330 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%2328, %2329) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %2331 = torch_c.from_builtin_tensor %2330 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %2331, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_2595 = torch.constant.int 5 - %2332 = torch.prims.convert_element_type %2331, %int5_2595 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2332, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_2596 = torch.constant.int 131072 - %none_2597 = torch.constant.none - %none_2598 = 
torch.constant.none - %cpu_2599 = torch.constant.device "cpu" - %false_2600 = torch.constant.bool false - %2333 = torch.aten.arange %int131072_2596, %none_2597, %none_2598, %cpu_2599, %false_2600 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_2601 = torch.constant.int 0 - %int128_2602 = torch.constant.int 128 - %none_2603 = torch.constant.none - %none_2604 = torch.constant.none - %cpu_2605 = torch.constant.device "cpu" - %false_2606 = torch.constant.bool false - %2334 = torch.aten.arange.start %int0_2601, %int128_2602, %none_2603, %none_2604, %cpu_2605, %false_2606 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_2607 = torch.constant.int 2 - %2335 = torch.aten.floor_divide.Scalar %2334, %int2_2607 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_2608 = torch.constant.int 6 - %2336 = torch.prims.convert_element_type %2335, %int6_2608 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_2609 = torch.constant.int 128 - %2337 = torch.aten.div.Scalar %2336, %int128_2609 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_2610 = torch.constant.float 2.000000e+00 - %2338 = torch.aten.mul.Scalar %2337, %float2.000000e00_2610 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_2611 = torch.constant.float 5.000000e+05 - %2339 = torch.aten.pow.Scalar %float5.000000e05_2611, %2338 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %2340 = torch.aten.reciprocal %2339 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_2612 = torch.constant.float 1.000000e+00 - %2341 = torch.aten.mul.Scalar %2340, %float1.000000e00_2612 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_2613 = torch.constant.int 1 - %2342 = torch.aten.unsqueeze %2333, %int1_2613 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_2614 = torch.constant.int 0 - %2343 = torch.aten.unsqueeze %2341, %int0_2614 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %2344 = torch.aten.mul.Tensor %2342, %2343 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_2615 = torch.constant.int 1 - %2345 = torch.aten.size.int %2291, %int1_2615 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int + %2324 = torch.aten.sub.Scalar %2323, %int1_2591, %int1_2592 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_2593 = torch.constant.int 3 + %2325 = torch.aten.div.Scalar %2324, %int3_2593 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_2594 = torch.constant.int 1 + %int1_2595 = torch.constant.int 1 + %2326 = torch.aten.rsub.Scalar %2325, %int1_2594, %int1_2595 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %2327 = torch.aten.mul.Tensor %2326, %2321 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_2596 = torch.constant.int 8 + %2328 = torch.aten.div.Scalar %2327, %int8_2596 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2329 = torch.aten.mul.Tensor %2325, %2321 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_2597 = torch.constant.int 1 + %2330 = torch.aten.add.Tensor %2328, %2329, %int1_2597 : !torch.vtensor<[64],f32>, 
!torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_2598 = torch.constant.float 2.048000e+03 + %2331 = torch.aten.lt.Scalar %2318, %float2.048000e03_2598 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2332 = torch.aten.bitwise_not %2331 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_2599 = torch.constant.float 8.192000e+03 + %2333 = torch.aten.gt.Scalar %2318, %float8.192000e03_2599 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2334 = torch.aten.bitwise_not %2333 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2335 = torch.aten.mul.Tensor %2332, %2334 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2336 = torch.aten.where.self %2335, %2330, %2321 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2337 = torch.prim.ListConstruct %2336, %2336 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_2600 = torch.constant.int -1 + %2338 = torch.aten.cat %2337, %int-1_2600 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_2601 = torch.constant.int 6 + %2339 = torch.prims.convert_element_type %2338, %int6_2601 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_2602 = torch.constant.int 1 + %2340 = torch.aten.unsqueeze %2310, %int1_2602 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_2603 = torch.constant.int 6 + %2341 = torch.prims.convert_element_type %2340, %int6_2603 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_2604 = torch.constant.int 0 + %2342 = torch.aten.unsqueeze %2339, %int0_2604 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_2605 = torch.constant.int 6 + %2343 = torch.prims.convert_element_type %2342, %int6_2605 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %2344 = torch.aten.mul.Tensor %2341, %2343 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %2345 = torch.aten.cos %2344 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_2606 = torch.constant.int 5 + %2346 = torch.prims.convert_element_type %2345, %int5_2606 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %2347 = torch.aten.sin %2344 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_2607 = torch.constant.int 5 + %2348 = torch.prims.convert_element_type %2347, %int5_2607 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_2608 = torch.constant.int 0 + %int0_2609 = torch.constant.int 0 + %int1_2610 = torch.constant.int 1 + %2349 = torch.aten.slice.Tensor %2346, %int0_2608, %int0_2609, %298, %int1_2610 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2349, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_2611 = torch.constant.int 1 + %int0_2612 = torch.constant.int 0 + %int9223372036854775807_2613 = torch.constant.int 9223372036854775807 + %int1_2614 = torch.constant.int 1 + %2350 = torch.aten.slice.Tensor %2349, %int1_2611, %int0_2612, %int9223372036854775807_2613, %int1_2614 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2350, 
[%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_2615 = torch.constant.int 0 %int0_2616 = torch.constant.int 0 - %2346 = torch.aten.add.int %int0_2616, %2345 : !torch.int, !torch.int -> !torch.int - %int0_2617 = torch.constant.int 0 - %int0_2618 = torch.constant.int 0 - %int1_2619 = torch.constant.int 1 - %2347 = torch.aten.slice.Tensor %2344, %int0_2617, %int0_2618, %2346, %int1_2619 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2347, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2620 = torch.constant.int 1 - %int0_2621 = torch.constant.int 0 - %int9223372036854775807_2622 = torch.constant.int 9223372036854775807 + %int1_2617 = torch.constant.int 1 + %2351 = torch.aten.slice.Tensor %2348, %int0_2615, %int0_2616, %298, %int1_2617 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2351, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_2618 = torch.constant.int 1 + %int0_2619 = torch.constant.int 0 + %int9223372036854775807_2620 = torch.constant.int 9223372036854775807 + %int1_2621 = torch.constant.int 1 + %2352 = torch.aten.slice.Tensor %2351, %int1_2618, %int0_2619, %int9223372036854775807_2620, %int1_2621 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2352, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_2622 = torch.constant.int 0 + %2353 = torch.aten.unsqueeze %2350, %int0_2622 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2353, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int1_2623 = torch.constant.int 1 - %2348 = torch.aten.slice.Tensor %2347, %int1_2620, %int0_2621, %int9223372036854775807_2622, %int1_2623 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2348, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2624 = torch.constant.int 1 - %int0_2625 = torch.constant.int 0 - %int9223372036854775807_2626 = torch.constant.int 9223372036854775807 - %int1_2627 = torch.constant.int 1 - %2349 = torch.aten.slice.Tensor %2348, %int1_2624, %int0_2625, %int9223372036854775807_2626, %int1_2627 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2349, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_2628 = torch.constant.int 0 - %2350 = torch.aten.unsqueeze %2349, %int0_2628 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2350, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_2629 = torch.constant.int 1 - %int0_2630 = torch.constant.int 0 - %int9223372036854775807_2631 = torch.constant.int 9223372036854775807 - %int1_2632 = torch.constant.int 1 - %2351 = torch.aten.slice.Tensor %2350, %int1_2629, %int0_2630, %int9223372036854775807_2631, %int1_2632 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2351, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : 
!torch.vtensor<[1,?,128],f32> - %int2_2633 = torch.constant.int 2 - %int0_2634 = torch.constant.int 0 - %int9223372036854775807_2635 = torch.constant.int 9223372036854775807 - %int1_2636 = torch.constant.int 1 - %2352 = torch.aten.slice.Tensor %2351, %int2_2633, %int0_2634, %int9223372036854775807_2635, %int1_2636 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2352, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_2637 = torch.constant.int 4 - %int1_2638 = torch.constant.int 1 - %int1_2639 = torch.constant.int 1 - %2353 = torch.prim.ListConstruct %int4_2637, %int1_2638, %int1_2639 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2354 = torch.aten.repeat %2352, %2353 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %2354, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_2640 = torch.constant.int 6 - %2355 = torch.prims.convert_element_type %2302, %int6_2640 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %2355, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %2356 = torch_c.to_builtin_tensor %2355 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %2357 = torch_c.to_builtin_tensor %2354 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %2358 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%2356, %2357) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %2359 = torch_c.from_builtin_tensor %2358 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %2359, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_2641 = torch.constant.int 5 - %2360 = torch.prims.convert_element_type %2359, %int5_2641 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2360, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_2642 = torch.constant.int 64 - %2361 = torch.aten.mul.Scalar %arg2, %int64_2642 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2361, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int20 = torch.constant.int 20 - %int1_2643 = torch.constant.int 1 - %2362 = torch.aten.add.Scalar %2361, %int20, %int1_2643 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2362, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_2644 = torch.constant.int 4 - %int32_2645 = torch.constant.int 32 - %int8_2646 = torch.constant.int 8 - %int128_2647 = torch.constant.int 128 - %2363 = torch.prim.ListConstruct %int4_2644, %398, %int32_2645, %int8_2646, %int128_2647 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2364 = torch.aten.view %2360, %2363 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2364, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_2648 = torch.constant.int 4 - %2365 = torch.aten.mul.int %int4_2648, %398 : !torch.int, !torch.int -> !torch.int - %int32_2649 = torch.constant.int 32 - %int8_2650 = torch.constant.int 8 - %int128_2651 = torch.constant.int 128 - 
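// The `-` lines below remove the old paged KV-cache update, which viewed the
// flat [?,2097152] cache as [?,32,2,32,8,128] (page, layer, K/V, position,
// kv-head, dim) and scattered rows with index_put. The `+` lines rework the
// slot arithmetic to arg2*32 + <layer offset>, then *2 with +0 for K and +1
// for V, interleaving the K and V planes; the per-layer offset constants
// differ between the two sides, likely because the removed and added hunks do
// not align block-for-block in this diff.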
%2366 = torch.prim.ListConstruct %2365, %int32_2649, %int8_2650, %int128_2651 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2367 = torch.aten.view %2364, %2366 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2367, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_2652 = torch.constant.int 4 - %2368 = torch.aten.mul.int %int4_2652, %398 : !torch.int, !torch.int -> !torch.int - %2369 = torch.prim.ListConstruct %2368 : (!torch.int) -> !torch.list - %2370 = torch.aten.view %2362, %2369 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2370, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_2653 = torch.constant.int 32 - %int2_2654 = torch.constant.int 2 - %int32_2655 = torch.constant.int 32 - %int8_2656 = torch.constant.int 8 - %int128_2657 = torch.constant.int 128 - %2371 = torch.prim.ListConstruct %389, %int32_2653, %int2_2654, %int32_2655, %int8_2656, %int128_2657 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2372 = torch.aten.view %2204, %2371 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2372, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2658 = torch.constant.int 32 - %2373 = torch.aten.mul.int %389, %int32_2658 : !torch.int, !torch.int -> !torch.int - %int2_2659 = torch.constant.int 2 - %2374 = torch.aten.mul.int %2373, %int2_2659 : !torch.int, !torch.int -> !torch.int + %int0_2624 = torch.constant.int 0 + %int9223372036854775807_2625 = torch.constant.int 9223372036854775807 + %int1_2626 = torch.constant.int 1 + %2354 = torch.aten.slice.Tensor %2353, %int1_2623, %int0_2624, %int9223372036854775807_2625, %int1_2626 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2354, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_2627 = torch.constant.int 2 + %2355 = torch.aten.unsqueeze %2354, %int2_2627 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2355, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_2628 = torch.constant.int 3 + %int0_2629 = torch.constant.int 0 + %int9223372036854775807_2630 = torch.constant.int 9223372036854775807 + %int1_2631 = torch.constant.int 1 + %2356 = torch.aten.slice.Tensor %2355, %int3_2628, %int0_2629, %int9223372036854775807_2630, %int1_2631 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2356, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_2632 = torch.constant.int 4 + %int1_2633 = torch.constant.int 1 + %int1_2634 = torch.constant.int 1 + %int1_2635 = torch.constant.int 1 + %2357 = torch.prim.ListConstruct %int4_2632, %int1_2633, %int1_2634, %int1_2635 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2358 = torch.aten.repeat %2356, %2357 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2358, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_2636 = torch.constant.int 0 + %2359 = 
torch.aten.unsqueeze %2352, %int0_2636 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2359, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_2637 = torch.constant.int 1 + %int0_2638 = torch.constant.int 0 + %int9223372036854775807_2639 = torch.constant.int 9223372036854775807 + %int1_2640 = torch.constant.int 1 + %2360 = torch.aten.slice.Tensor %2359, %int1_2637, %int0_2638, %int9223372036854775807_2639, %int1_2640 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2360, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_2641 = torch.constant.int 2 + %2361 = torch.aten.unsqueeze %2360, %int2_2641 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2361, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_2642 = torch.constant.int 3 + %int0_2643 = torch.constant.int 0 + %int9223372036854775807_2644 = torch.constant.int 9223372036854775807 + %int1_2645 = torch.constant.int 1 + %2362 = torch.aten.slice.Tensor %2361, %int3_2642, %int0_2643, %int9223372036854775807_2644, %int1_2645 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2362, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_2646 = torch.constant.int 4 + %int1_2647 = torch.constant.int 1 + %int1_2648 = torch.constant.int 1 + %int1_2649 = torch.constant.int 1 + %2363 = torch.prim.ListConstruct %int4_2646, %int1_2647, %int1_2648, %int1_2649 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2364 = torch.aten.repeat %2362, %2363 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2364, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %2365 = torch.aten.mul.Tensor %2244, %2358 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2365, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_2650 = torch.constant.int 3 + %int0_2651 = torch.constant.int 0 + %int64_2652 = torch.constant.int 64 + %int1_2653 = torch.constant.int 1 + %2366 = torch.aten.slice.Tensor %2244, %int3_2650, %int0_2651, %int64_2652, %int1_2653 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %2366, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_2654 = torch.constant.int 3 + %int64_2655 = torch.constant.int 64 + %int9223372036854775807_2656 = torch.constant.int 9223372036854775807 + %int1_2657 = torch.constant.int 1 + %2367 = torch.aten.slice.Tensor %2244, %int3_2654, %int64_2655, %int9223372036854775807_2656, %int1_2657 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %2367, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %2368 = torch.aten.neg %2367 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %2368, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : 
!torch.vtensor<[4,?,8,64],f16> + %2369 = torch.prim.ListConstruct %2368, %2366 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_2658 = torch.constant.int -1 + %2370 = torch.aten.cat %2369, %int-1_2658 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2370, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %2371 = torch.aten.mul.Tensor %2370, %2364 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2371, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_2659 = torch.constant.int 1 + %2372 = torch.aten.add.Tensor %2365, %2371, %int1_2659 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2372, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> %int32_2660 = torch.constant.int 32 - %int8_2661 = torch.constant.int 8 - %int128_2662 = torch.constant.int 128 - %2375 = torch.prim.ListConstruct %2374, %int32_2660, %int8_2661, %int128_2662 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2376 = torch.aten.view %2372, %2375 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2376, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %2377 = torch.prim.ListConstruct %2370 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_2663 = torch.constant.bool false - %2378 = torch.aten.index_put %2376, %2377, %2367, %false_2663 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2378, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_2664 = torch.constant.int 32 - %int2_2665 = torch.constant.int 2 + %2373 = torch.aten.mul.Scalar %arg2, %int32_2660 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2373, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int7 = torch.constant.int 7 + %int1_2661 = torch.constant.int 1 + %2374 = torch.aten.add.Scalar %2373, %int7, %int1_2661 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2374, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_2662 = torch.constant.int 2 + %2375 = torch.aten.mul.Scalar %2374, %int2_2662 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2375, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_2663 = torch.constant.int 0 + %int1_2664 = torch.constant.int 1 + %2376 = torch.aten.add.Scalar %2375, %int0_2663, %int1_2664 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2376, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %2377 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %2378 = torch.aten.view %2376, %2377 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %2378, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_2665 = torch.constant.int 4 %int32_2666 = torch.constant.int 32 %int8_2667 = torch.constant.int 8 
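// New K write path: the rotated keys are reshaped to [?,32,8,128], transposed
// to heads-major [?,8,32,128], and index_put into the cache reinterpreted as
// [?,32,2,8,32,128] -- i.e. the new layout appears to key the cache by
// (page, layer, K/V, kv-head, position, dim) instead of the removed
// (page, layer, K/V, position, kv-head, dim) ordering.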
%int128_2668 = torch.constant.int 128 - %2379 = torch.prim.ListConstruct %389, %int32_2664, %int2_2665, %int32_2666, %int8_2667, %int128_2668 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2380 = torch.aten.view %2378, %2379 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2380, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2669 = torch.constant.int 2097152 - %2381 = torch.prim.ListConstruct %389, %int2097152_2669 : (!torch.int, !torch.int) -> !torch.list - %2382 = torch.aten.view %2380, %2381 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2382, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_2670 = torch.constant.int 32 - %int2_2671 = torch.constant.int 2 - %int32_2672 = torch.constant.int 32 - %int8_2673 = torch.constant.int 8 - %int128_2674 = torch.constant.int 128 - %2383 = torch.prim.ListConstruct %389, %int32_2670, %int2_2671, %int32_2672, %int8_2673, %int128_2674 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2384 = torch.aten.view %2382, %2383 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2384, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> + %2379 = torch.prim.ListConstruct %int4_2665, %296, %int32_2666, %int8_2667, %int128_2668 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2380 = torch.aten.view %2372, %2379 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2380, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_2669 = torch.constant.int 32 + %int8_2670 = torch.constant.int 8 + %int128_2671 = torch.constant.int 128 + %2381 = torch.prim.ListConstruct %504, %int32_2669, %int8_2670, %int128_2671 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2382 = torch.aten.view %2380, %2381 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %2382, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_2672 = torch.constant.int 1 + %int2_2673 = torch.constant.int 2 + %2383 = torch.aten.transpose.int %2382, %int1_2672, %int2_2673 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2383, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_2674 = torch.constant.int 5 + %2384 = torch.prims.convert_element_type %2383, %int5_2674 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2384, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> %int32_2675 = torch.constant.int 32 - %int8_2676 = torch.constant.int 8 - %int128_2677 = torch.constant.int 128 - %2385 = torch.prim.ListConstruct %2374, %int32_2675, %int8_2676, %int128_2677 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2386 = torch.aten.view %2384, %2385 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2386, [%293], affine_map<()[s0] -> (s0 * 64, 32, 
8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_2678 = torch.constant.int 4 - %int32_2679 = torch.constant.int 32 + %int2_2676 = torch.constant.int 2 + %int8_2677 = torch.constant.int 8 + %int32_2678 = torch.constant.int 32 + %int128_2679 = torch.constant.int 128 + %2385 = torch.prim.ListConstruct %297, %int32_2675, %int2_2676, %int8_2677, %int32_2678, %int128_2679 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2386 = torch.aten.view %2148, %2385 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2386, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> %int8_2680 = torch.constant.int 8 - %int128_2681 = torch.constant.int 128 - %2387 = torch.prim.ListConstruct %int4_2678, %398, %int32_2679, %int8_2680, %int128_2681 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2388 = torch.aten.view %2304, %2387 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2388, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_2682 = torch.constant.int 4 - %2389 = torch.aten.mul.int %int4_2682, %398 : !torch.int, !torch.int -> !torch.int - %int32_2683 = torch.constant.int 32 - %int8_2684 = torch.constant.int 8 - %int128_2685 = torch.constant.int 128 - %2390 = torch.prim.ListConstruct %2389, %int32_2683, %int8_2684, %int128_2685 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2391 = torch.aten.view %2388, %2390 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2391, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_2686 = torch.constant.int 1 - %int1_2687 = torch.constant.int 1 - %2392 = torch.aten.add.Scalar %2362, %int1_2686, %int1_2687 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2392, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_2688 = torch.constant.int 4 - %2393 = torch.aten.mul.int %int4_2688, %398 : !torch.int, !torch.int -> !torch.int - %2394 = torch.prim.ListConstruct %2393 : (!torch.int) -> !torch.list - %2395 = torch.aten.view %2392, %2394 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2395, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %2396 = torch.prim.ListConstruct %2395 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_2689 = torch.constant.bool false - %2397 = torch.aten.index_put %2386, %2396, %2391, %false_2689 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2397, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int32_2681 = torch.constant.int 32 + %int128_2682 = torch.constant.int 128 + %2387 = torch.prim.ListConstruct %497, %int8_2680, %int32_2681, %int128_2682 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2388 = torch.aten.view %2386, %2387 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2388, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %2389 = torch.prim.ListConstruct %2378 : (!torch.vtensor<[?],si64>) -> 
!torch.list> + %false_2683 = torch.constant.bool false + %2390 = torch.aten.index_put %2388, %2389, %2384, %false_2683 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2390, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_2684 = torch.constant.int 32 + %int2_2685 = torch.constant.int 2 + %int8_2686 = torch.constant.int 8 + %int32_2687 = torch.constant.int 32 + %int128_2688 = torch.constant.int 128 + %2391 = torch.prim.ListConstruct %297, %int32_2684, %int2_2685, %int8_2686, %int32_2687, %int128_2688 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2392 = torch.aten.view %2390, %2391 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2392, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_2689 = torch.constant.int 2097152 + %2393 = torch.prim.ListConstruct %297, %int2097152_2689 : (!torch.int, !torch.int) -> !torch.list + %2394 = torch.aten.view %2392, %2393 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2394, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> %int32_2690 = torch.constant.int 32 %int2_2691 = torch.constant.int 2 - %int32_2692 = torch.constant.int 32 - %int8_2693 = torch.constant.int 8 + %int8_2692 = torch.constant.int 8 + %int32_2693 = torch.constant.int 32 %int128_2694 = torch.constant.int 128 - %2398 = torch.prim.ListConstruct %389, %int32_2690, %int2_2691, %int32_2692, %int8_2693, %int128_2694 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2399 = torch.aten.view %2397, %2398 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2399, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2695 = torch.constant.int 2097152 - %2400 = torch.prim.ListConstruct %389, %int2097152_2695 : (!torch.int, !torch.int) -> !torch.list - %2401 = torch.aten.view %2399, %2400 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2401, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_2696 = torch.constant.int -2 - %2402 = torch.aten.unsqueeze %2360, %int-2_2696 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2402, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_2697 = torch.constant.int 4 - %int8_2698 = torch.constant.int 8 - %int4_2699 = torch.constant.int 4 - %int128_2700 = torch.constant.int 128 - %2403 = torch.prim.ListConstruct %int4_2697, %2345, %int8_2698, %int4_2699, %int128_2700 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2701 = torch.constant.bool false - %2404 = torch.aten.expand %2402, %2403, %false_2701 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2404, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2702 = torch.constant.int 0 - %2405 = torch.aten.clone %2404, %int0_2702 : !torch.vtensor<[4,?,8,4,128],f16>, 
!torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2405, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2703 = torch.constant.int 4 - %int32_2704 = torch.constant.int 32 - %int128_2705 = torch.constant.int 128 - %2406 = torch.prim.ListConstruct %int4_2703, %2345, %int32_2704, %int128_2705 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2407 = torch.aten._unsafe_view %2405, %2406 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2407, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_2706 = torch.constant.int -2 - %2408 = torch.aten.unsqueeze %2304, %int-2_2706 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2408, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_2707 = torch.constant.int 1 - %2409 = torch.aten.size.int %2298, %int1_2707 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_2708 = torch.constant.int 4 + %2395 = torch.prim.ListConstruct %297, %int32_2690, %int2_2691, %int8_2692, %int32_2693, %int128_2694 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2396 = torch.aten.view %2394, %2395 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2396, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_2695 = torch.constant.int 8 + %int32_2696 = torch.constant.int 32 + %int128_2697 = torch.constant.int 128 + %2397 = torch.prim.ListConstruct %497, %int8_2695, %int32_2696, %int128_2697 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2398 = torch.aten.view %2396, %2397 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2398, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_2698 = torch.constant.int 32 + %2399 = torch.aten.mul.Scalar %arg2, %int32_2698 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2399, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int7_2699 = torch.constant.int 7 + %int1_2700 = torch.constant.int 1 + %2400 = torch.aten.add.Scalar %2399, %int7_2699, %int1_2700 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2400, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_2701 = torch.constant.int 2 + %2401 = torch.aten.mul.Scalar %2400, %int2_2701 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2401, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_2702 = torch.constant.int 1 + %int1_2703 = torch.constant.int 1 + %2402 = torch.aten.add.Scalar %2401, %int1_2702, %int1_2703 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2402, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %2403 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %2404 = torch.aten.view %2402, %2403 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %2404, [%294], affine_map<()[s0] -> (s0 * 4)> : 
!torch.vtensor<[?],si64> + %int4_2704 = torch.constant.int 4 + %int32_2705 = torch.constant.int 32 + %int8_2706 = torch.constant.int 8 + %int128_2707 = torch.constant.int 128 + %2405 = torch.prim.ListConstruct %int4_2704, %296, %int32_2705, %int8_2706, %int128_2707 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2406 = torch.aten.view %2246, %2405 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2406, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_2708 = torch.constant.int 32 %int8_2709 = torch.constant.int 8 - %int4_2710 = torch.constant.int 4 - %int128_2711 = torch.constant.int 128 - %2410 = torch.prim.ListConstruct %int4_2708, %2409, %int8_2709, %int4_2710, %int128_2711 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2712 = torch.constant.bool false - %2411 = torch.aten.expand %2408, %2410, %false_2712 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2411, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2713 = torch.constant.int 0 - %2412 = torch.aten.clone %2411, %int0_2713 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2412, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2714 = torch.constant.int 4 + %int128_2710 = torch.constant.int 128 + %2407 = torch.prim.ListConstruct %504, %int32_2708, %int8_2709, %int128_2710 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2408 = torch.aten.view %2406, %2407 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %2408, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_2711 = torch.constant.int 1 + %int2_2712 = torch.constant.int 2 + %2409 = torch.aten.transpose.int %2408, %int1_2711, %int2_2712 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2409, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_2713 = torch.constant.int 5 + %2410 = torch.prims.convert_element_type %2409, %int5_2713 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2410, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %2411 = torch.prim.ListConstruct %2404 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_2714 = torch.constant.bool false + %2412 = torch.aten.index_put %2398, %2411, %2410, %false_2714 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2412, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> %int32_2715 = torch.constant.int 32 - %int128_2716 = torch.constant.int 128 - %2413 = torch.prim.ListConstruct %int4_2714, %2409, %int32_2715, %int128_2716 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2414 = torch.aten._unsafe_view %2412, %2413 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2414, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : 
!torch.vtensor<[4,?,32,128],f16> - %int1_2717 = torch.constant.int 1 - %int2_2718 = torch.constant.int 2 - %2415 = torch.aten.transpose.int %2332, %int1_2717, %int2_2718 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2415, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_2719 = torch.constant.int 1 - %int2_2720 = torch.constant.int 2 - %2416 = torch.aten.transpose.int %2407, %int1_2719, %int2_2720 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2416, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_2721 = torch.constant.int 1 - %int2_2722 = torch.constant.int 2 - %2417 = torch.aten.transpose.int %2414, %int1_2721, %int2_2722 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2417, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_2723 = torch.constant.float 0.000000e+00 - %true_2724 = torch.constant.bool true - %none_2725 = torch.constant.none - %none_2726 = torch.constant.none - %2418:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2415, %2416, %2417, %float0.000000e00_2723, %true_2724, %none_2725, %none_2726) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %2418#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_2727 = torch.constant.int 1 - %int2_2728 = torch.constant.int 2 - %2419 = torch.aten.transpose.int %2418#0, %int1_2727, %int2_2728 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2419, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_2729 = torch.constant.int 4 - %int4096_2730 = torch.constant.int 4096 - %2420 = torch.prim.ListConstruct %int4_2729, %2317, %int4096_2730 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2421 = torch.aten.view %2419, %2420 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2421, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int2_2716 = torch.constant.int 2 + %int8_2717 = torch.constant.int 8 + %int32_2718 = torch.constant.int 32 + %int128_2719 = torch.constant.int 128 + %2413 = torch.prim.ListConstruct %297, %int32_2715, %int2_2716, %int8_2717, %int32_2718, %int128_2719 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2414 = torch.aten.view %2412, %2413 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2414, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_2720 = torch.constant.int 2097152 + %2415 = torch.prim.ListConstruct %297, %int2097152_2720 : (!torch.int, !torch.int) -> !torch.list + %2416 = torch.aten.view %2414, %2415 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2416, [%295], affine_map<()[s0] -> (s0, 2097152)> : 
!torch.vtensor<[?,2097152],f16> + %int-2_2721 = torch.constant.int -2 + %2417 = torch.aten.unsqueeze %2372, %int-2_2721 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2417, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_2722 = torch.constant.int 4 + %int8_2723 = torch.constant.int 8 + %int4_2724 = torch.constant.int 4 + %int128_2725 = torch.constant.int 128 + %2418 = torch.prim.ListConstruct %int4_2722, %298, %int8_2723, %int4_2724, %int128_2725 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2726 = torch.constant.bool false + %2419 = torch.aten.expand %2417, %2418, %false_2726 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2419, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2727 = torch.constant.int 0 + %2420 = torch.aten.clone %2419, %int0_2727 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2420, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_2728 = torch.constant.int 4 + %int32_2729 = torch.constant.int 32 + %int128_2730 = torch.constant.int 128 + %2421 = torch.prim.ListConstruct %int4_2728, %298, %int32_2729, %int128_2730 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2422 = torch.aten._unsafe_view %2420, %2421 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2422, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int-2_2731 = torch.constant.int -2 - %int-1_2732 = torch.constant.int -1 - %2422 = torch.aten.transpose.int %95, %int-2_2731, %int-1_2732 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2733 = torch.constant.int 4 - %2423 = torch.aten.mul.int %int4_2733, %2317 : !torch.int, !torch.int -> !torch.int - %int4096_2734 = torch.constant.int 4096 - %2424 = torch.prim.ListConstruct %2423, %int4096_2734 : (!torch.int, !torch.int) -> !torch.list - %2425 = torch.aten.view %2421, %2424 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2425, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2426 = torch.aten.mm %2425, %2422 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2426, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_2735 = torch.constant.int 4 - %int4096_2736 = torch.constant.int 4096 - %2427 = torch.prim.ListConstruct %int4_2735, %2317, %int4096_2736 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2428 = torch.aten.view %2426, %2427 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2428, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_2737 = torch.constant.int 1 - %2429 = torch.aten.add.Tensor %2267, %2428, %int1_2737 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2429, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_2738 = torch.constant.int 6 - %2430 = 
torch.prims.convert_element_type %2429, %int6_2738 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2430, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_2739 = torch.constant.int 2 - %2431 = torch.aten.pow.Tensor_Scalar %2430, %int2_2739 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2431, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_2740 = torch.constant.int -1 - %2432 = torch.prim.ListConstruct %int-1_2740 : (!torch.int) -> !torch.list - %true_2741 = torch.constant.bool true - %none_2742 = torch.constant.none - %2433 = torch.aten.mean.dim %2431, %2432, %true_2741, %none_2742 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2433, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_2743 = torch.constant.float 9.9999997473787516E-6 - %int1_2744 = torch.constant.int 1 - %2434 = torch.aten.add.Scalar %2433, %float9.999990e-06_2743, %int1_2744 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2434, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2435 = torch.aten.rsqrt %2434 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2435, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2436 = torch.aten.mul.Tensor %2430, %2435 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2436, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2745 = torch.constant.int 5 - %2437 = torch.prims.convert_element_type %2436, %int5_2745 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2437, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %2438 = torch.aten.mul.Tensor %96, %2437 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2438, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2746 = torch.constant.int 5 - %2439 = torch.prims.convert_element_type %2438, %int5_2746 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2439, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_2747 = torch.constant.int -2 - %int-1_2748 = torch.constant.int -1 - %2440 = torch.aten.transpose.int %97, %int-2_2747, %int-1_2748 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_2749 = torch.constant.int 4 - %2441 = torch.aten.mul.int %int4_2749, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2750 = torch.constant.int 4096 - %2442 = torch.prim.ListConstruct %2441, %int4096_2750 : (!torch.int, !torch.int) -> !torch.list - %2443 = torch.aten.view %2439, %2442 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2443, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2444 = torch.aten.mm %2443, %2440 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - 
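// The `-` lines continuing below finish removing the old block epilogue: the
// SwiGLU FFN (gate and up projections, silu, elementwise mul, down
// projection), the residual add, and the trailing f32 RMSNorm. The `+` lines
// that follow instead expand the 8 KV heads to 32 query heads
// (unsqueeze/expand/_unsafe_view) and call
// _scaled_dot_product_flash_attention_for_cpu with an explicit [4,1,?,?] f16
// mask (%327) in place of the old is_causal=true form; the output projection,
// residual, and norm then reappear after the new attention call.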
torch.bind_symbolic_shape %2444, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_2751 = torch.constant.int 4 - %int14336_2752 = torch.constant.int 14336 - %2445 = torch.prim.ListConstruct %int4_2751, %306, %int14336_2752 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2446 = torch.aten.view %2444, %2445 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2446, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %2447 = torch.aten.silu %2446 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2447, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_2753 = torch.constant.int -2 - %int-1_2754 = torch.constant.int -1 - %2448 = torch.aten.transpose.int %98, %int-2_2753, %int-1_2754 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_2755 = torch.constant.int 4 - %2449 = torch.aten.mul.int %int4_2755, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2756 = torch.constant.int 4096 - %2450 = torch.prim.ListConstruct %2449, %int4096_2756 : (!torch.int, !torch.int) -> !torch.list - %2451 = torch.aten.view %2439, %2450 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2451, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2452 = torch.aten.mm %2451, %2448 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2452, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_2757 = torch.constant.int 4 - %int14336_2758 = torch.constant.int 14336 - %2453 = torch.prim.ListConstruct %int4_2757, %306, %int14336_2758 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2454 = torch.aten.view %2452, %2453 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2454, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %2455 = torch.aten.mul.Tensor %2447, %2454 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2455, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_2759 = torch.constant.int -2 - %int-1_2760 = torch.constant.int -1 - %2456 = torch.aten.transpose.int %99, %int-2_2759, %int-1_2760 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_2761 = torch.constant.int 1 - %2457 = torch.aten.size.int %2446, %int1_2761 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_2762 = torch.constant.int 4 - %2458 = torch.aten.mul.int %int4_2762, %2457 : !torch.int, !torch.int -> !torch.int - %int14336_2763 = torch.constant.int 14336 - %2459 = torch.prim.ListConstruct %2458, %int14336_2763 : (!torch.int, !torch.int) -> !torch.list - %2460 = torch.aten.view %2455, %2459 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2460, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %2461 = torch.aten.mm %2460, %2456 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2461, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : 
!torch.vtensor<[?,4096],f16> - %int4_2764 = torch.constant.int 4 - %int4096_2765 = torch.constant.int 4096 - %2462 = torch.prim.ListConstruct %int4_2764, %2457, %int4096_2765 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2463 = torch.aten.view %2461, %2462 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2463, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_2766 = torch.constant.int 1 - %2464 = torch.aten.add.Tensor %2429, %2463, %int1_2766 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2464, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_2767 = torch.constant.int 6 - %2465 = torch.prims.convert_element_type %2464, %int6_2767 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2465, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_2768 = torch.constant.int 2 - %2466 = torch.aten.pow.Tensor_Scalar %2465, %int2_2768 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2466, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_2769 = torch.constant.int -1 - %2467 = torch.prim.ListConstruct %int-1_2769 : (!torch.int) -> !torch.list - %true_2770 = torch.constant.bool true - %none_2771 = torch.constant.none - %2468 = torch.aten.mean.dim %2466, %2467, %true_2770, %none_2771 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2468, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_2772 = torch.constant.float 9.9999997473787516E-6 - %int1_2773 = torch.constant.int 1 - %2469 = torch.aten.add.Scalar %2468, %float9.999990e-06_2772, %int1_2773 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2469, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2470 = torch.aten.rsqrt %2469 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2470, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2471 = torch.aten.mul.Tensor %2465, %2470 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2471, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2774 = torch.constant.int 5 - %2472 = torch.prims.convert_element_type %2471, %int5_2774 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2472, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %2473 = torch.aten.mul.Tensor %100, %2472 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2473, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2775 = torch.constant.int 5 - %2474 = torch.prims.convert_element_type %2473, %int5_2775 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2474, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %2423 = torch.aten.unsqueeze 
%2246, %int-2_2731 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2423, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_2732 = torch.constant.int 4 + %int8_2733 = torch.constant.int 8 + %int4_2734 = torch.constant.int 4 + %int128_2735 = torch.constant.int 128 + %2424 = torch.prim.ListConstruct %int4_2732, %298, %int8_2733, %int4_2734, %int128_2735 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2736 = torch.constant.bool false + %2425 = torch.aten.expand %2423, %2424, %false_2736 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2425, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2737 = torch.constant.int 0 + %2426 = torch.aten.clone %2425, %int0_2737 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2426, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_2738 = torch.constant.int 4 + %int32_2739 = torch.constant.int 32 + %int128_2740 = torch.constant.int 128 + %2427 = torch.prim.ListConstruct %int4_2738, %298, %int32_2739, %int128_2740 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2428 = torch.aten._unsafe_view %2426, %2427 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2428, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_2741 = torch.constant.int 1 + %int2_2742 = torch.constant.int 2 + %2429 = torch.aten.transpose.int %2309, %int1_2741, %int2_2742 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2429, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2743 = torch.constant.int 1 + %int2_2744 = torch.constant.int 2 + %2430 = torch.aten.transpose.int %2422, %int1_2743, %int2_2744 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2430, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2745 = torch.constant.int 1 + %int2_2746 = torch.constant.int 2 + %2431 = torch.aten.transpose.int %2428, %int1_2745, %int2_2746 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2431, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_2747 = torch.constant.float 0.000000e+00 + %false_2748 = torch.constant.bool false + %none_2749 = torch.constant.none + %2432:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2429, %2430, %2431, %float0.000000e00_2747, %false_2748, %327, %none_2749) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %2432#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2750 = torch.constant.int 1 + %int2_2751 = torch.constant.int 2 + %2433 = torch.aten.transpose.int %2432#0, %int1_2750, %int2_2751 : 
!torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2433, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_2752 = torch.constant.int 4 + %int4096_2753 = torch.constant.int 4096 + %2434 = torch.prim.ListConstruct %int4_2752, %298, %int4096_2753 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2435 = torch.aten.view %2433, %2434 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2435, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_2754 = torch.constant.int -2 + %int-1_2755 = torch.constant.int -1 + %2436 = torch.aten.transpose.int %69, %int-2_2754, %int-1_2755 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2756 = torch.constant.int 5 + %2437 = torch.prims.convert_element_type %2436, %int5_2756 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_2757 = torch.constant.int 4096 + %2438 = torch.prim.ListConstruct %342, %int4096_2757 : (!torch.int, !torch.int) -> !torch.list + %2439 = torch.aten.view %2435, %2438 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2439, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2440 = torch.aten.mm %2439, %2437 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2440, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_2758 = torch.constant.int 4 + %int4096_2759 = torch.constant.int 4096 + %2441 = torch.prim.ListConstruct %int4_2758, %298, %int4096_2759 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2442 = torch.aten.view %2440, %2441 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2442, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_2760 = torch.constant.int 1 + %2443 = torch.aten.add.Tensor %2209, %2442, %int1_2760 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2443, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_2761 = torch.constant.int 6 + %2444 = torch.prims.convert_element_type %2443, %int6_2761 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2444, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_2762 = torch.constant.int 2 + %2445 = torch.aten.pow.Tensor_Scalar %2444, %int2_2762 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2445, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_2763 = torch.constant.int -1 + %2446 = torch.prim.ListConstruct %int-1_2763 : (!torch.int) -> !torch.list + %true_2764 = torch.constant.bool true + %none_2765 = torch.constant.none + %2447 = torch.aten.mean.dim %2445, %2446, %true_2764, %none_2765 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2447, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_2766 = torch.constant.float 
9.9999997473787516E-6 + %int1_2767 = torch.constant.int 1 + %2448 = torch.aten.add.Scalar %2447, %float9.999990e-06_2766, %int1_2767 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2448, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2449 = torch.aten.rsqrt %2448 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2449, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2450 = torch.aten.mul.Tensor %2444, %2449 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2450, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_2768 = torch.constant.int 5 + %2451 = torch.prims.convert_element_type %2450, %int5_2768 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2451, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %2452 = torch.aten.mul.Tensor %70, %2451 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2452, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_2769 = torch.constant.int 5 + %2453 = torch.prims.convert_element_type %2452, %int5_2769 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2453, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_2770 = torch.constant.int -2 + %int-1_2771 = torch.constant.int -1 + %2454 = torch.aten.transpose.int %71, %int-2_2770, %int-1_2771 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2772 = torch.constant.int 5 + %2455 = torch.prims.convert_element_type %2454, %int5_2772 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_2773 = torch.constant.int 4096 + %2456 = torch.prim.ListConstruct %342, %int4096_2773 : (!torch.int, !torch.int) -> !torch.list + %2457 = torch.aten.view %2453, %2456 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2457, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2458 = torch.aten.mm %2457, %2455 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %2458, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_2774 = torch.constant.int 4 + %int14336_2775 = torch.constant.int 14336 + %2459 = torch.prim.ListConstruct %int4_2774, %298, %int14336_2775 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2460 = torch.aten.view %2458, %2459 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2460, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %2461 = torch.aten.silu %2460 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2461, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> %int-2_2776 = torch.constant.int -2 %int-1_2777 = torch.constant.int -1 - %2475 = torch.aten.transpose.int %101, %int-2_2776, %int-1_2777 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> 
!torch.vtensor<[4096,4096],f16> - %int4_2778 = torch.constant.int 4 - %2476 = torch.aten.mul.int %int4_2778, %306 : !torch.int, !torch.int -> !torch.int + %2462 = torch.aten.transpose.int %72, %int-2_2776, %int-1_2777 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2778 = torch.constant.int 5 + %2463 = torch.prims.convert_element_type %2462, %int5_2778 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> %int4096_2779 = torch.constant.int 4096 - %2477 = torch.prim.ListConstruct %2476, %int4096_2779 : (!torch.int, !torch.int) -> !torch.list - %2478 = torch.aten.view %2474, %2477 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2478, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2479 = torch.aten.mm %2478, %2475 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2479, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2464 = torch.prim.ListConstruct %342, %int4096_2779 : (!torch.int, !torch.int) -> !torch.list + %2465 = torch.aten.view %2453, %2464 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2465, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2466 = torch.aten.mm %2465, %2463 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %2466, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> %int4_2780 = torch.constant.int 4 - %int4096_2781 = torch.constant.int 4096 - %2480 = torch.prim.ListConstruct %int4_2780, %306, %int4096_2781 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2481 = torch.aten.view %2479, %2480 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2481, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int14336_2781 = torch.constant.int 14336 + %2467 = torch.prim.ListConstruct %int4_2780, %298, %int14336_2781 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2468 = torch.aten.view %2466, %2467 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2468, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %2469 = torch.aten.mul.Tensor %2461, %2468 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2469, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> %int-2_2782 = torch.constant.int -2 %int-1_2783 = torch.constant.int -1 - %2482 = torch.aten.transpose.int %102, %int-2_2782, %int-1_2783 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2784 = torch.constant.int 4 - %2483 = torch.aten.mul.int %int4_2784, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2785 = torch.constant.int 4096 - %2484 = torch.prim.ListConstruct %2483, %int4096_2785 : (!torch.int, !torch.int) -> !torch.list - %2485 = torch.aten.view %2474, %2484 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2485, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2486 = torch.aten.mm %2485, %2482 : 
!torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %2486, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %2470 = torch.aten.transpose.int %73, %int-2_2782, %int-1_2783 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_2784 = torch.constant.int 5 + %2471 = torch.prims.convert_element_type %2470, %int5_2784 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_2785 = torch.constant.int 14336 + %2472 = torch.prim.ListConstruct %342, %int14336_2785 : (!torch.int, !torch.int) -> !torch.list + %2473 = torch.aten.view %2469, %2472 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %2473, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %2474 = torch.aten.mm %2473, %2471 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2474, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> %int4_2786 = torch.constant.int 4 - %int1024_2787 = torch.constant.int 1024 - %2487 = torch.prim.ListConstruct %int4_2786, %306, %int1024_2787 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2488 = torch.aten.view %2486, %2487 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %2488, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_2788 = torch.constant.int -2 - %int-1_2789 = torch.constant.int -1 - %2489 = torch.aten.transpose.int %103, %int-2_2788, %int-1_2789 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2790 = torch.constant.int 4 - %2490 = torch.aten.mul.int %int4_2790, %306 : !torch.int, !torch.int -> !torch.int - %int4096_2791 = torch.constant.int 4096 - %2491 = torch.prim.ListConstruct %2490, %int4096_2791 : (!torch.int, !torch.int) -> !torch.list - %2492 = torch.aten.view %2474, %2491 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2492, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2493 = torch.aten.mm %2492, %2489 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %2493, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_2792 = torch.constant.int 4 - %int1024_2793 = torch.constant.int 1024 - %2494 = torch.prim.ListConstruct %int4_2792, %306, %int1024_2793 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2495 = torch.aten.view %2493, %2494 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %2495, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_2794 = torch.constant.int 4 - %int32_2795 = torch.constant.int 32 - %int128_2796 = torch.constant.int 128 - %2496 = torch.prim.ListConstruct %int4_2794, %306, %int32_2795, %int128_2796 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2497 = torch.aten.view %2481, %2496 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2497, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_2797 = torch.constant.int 4 - 
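For readability: the + ops in this stretch are the feed-forward tail of a decoder layer, namely the residual add (%2443), an RMSNorm computed in f32 (%2444..%2453), the SwiGLU gate/up matmuls with torch.aten.silu (%2454..%2469), and the down projection (%2470..%2476). The f16-to-f16 torch.prims.convert_element_type ops (%2455, %2463, %2471) look like no-op casts left in by the exporter. Below is a minimal PyTorch-level sketch of the same computation, assuming the [out_features, in_features] weight layout of the globals; the helper names are illustrative and not part of the IR.

import torch
import torch.nn.functional as F

def rms_norm(x, w, eps=9.9999997473787516e-6):
    # %2444..%2453: normalize in f32, multiply by the f32 weight, return f16
    xf = x.float()
    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (w * xf.half()).half()

def swiglu_ffn(x, w_gate, w_up, w_down):
    # weights are [out_features, in_features], hence the transposes; the IR
    # flattens [4, T, 4096] to [4*T, 4096] (%342) before each torch.aten.mm
    b, t, d = x.shape
    x2d = x.reshape(b * t, d)
    gate = F.silu(x2d @ w_gate.T)                      # %2458, %2461
    up = x2d @ w_up.T                                  # %2466
    return ((gate * up) @ w_down.T).reshape(b, t, -1)  # %2469, %2474

# Residual wiring as in %2443 and %2477:
#   h = x + attn_out
#   out = h + swiglu_ffn(rms_norm(h, ffn_norm_weight), w_gate, w_up, w_down)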
%int8_2798 = torch.constant.int 8 - %int128_2799 = torch.constant.int 128 - %2498 = torch.prim.ListConstruct %int4_2797, %306, %int8_2798, %int128_2799 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2499 = torch.aten.view %2488, %2498 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2499, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_2800 = torch.constant.int 4 - %int8_2801 = torch.constant.int 8 - %int128_2802 = torch.constant.int 128 - %2500 = torch.prim.ListConstruct %int4_2800, %306, %int8_2801, %int128_2802 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2501 = torch.aten.view %2495, %2500 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2501, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_2803 = torch.constant.int 131072 - %none_2804 = torch.constant.none - %none_2805 = torch.constant.none - %cpu_2806 = torch.constant.device "cpu" - %false_2807 = torch.constant.bool false - %2502 = torch.aten.arange %int131072_2803, %none_2804, %none_2805, %cpu_2806, %false_2807 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_2808 = torch.constant.int 0 - %int128_2809 = torch.constant.int 128 - %none_2810 = torch.constant.none - %none_2811 = torch.constant.none - %cpu_2812 = torch.constant.device "cpu" - %false_2813 = torch.constant.bool false - %2503 = torch.aten.arange.start %int0_2808, %int128_2809, %none_2810, %none_2811, %cpu_2812, %false_2813 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_2814 = torch.constant.int 2 - %2504 = torch.aten.floor_divide.Scalar %2503, %int2_2814 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_2815 = torch.constant.int 6 - %2505 = torch.prims.convert_element_type %2504, %int6_2815 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_2816 = torch.constant.int 128 - %2506 = torch.aten.div.Scalar %2505, %int128_2816 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_2817 = torch.constant.float 2.000000e+00 - %2507 = torch.aten.mul.Scalar %2506, %float2.000000e00_2817 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_2818 = torch.constant.float 5.000000e+05 - %2508 = torch.aten.pow.Scalar %float5.000000e05_2818, %2507 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %2509 = torch.aten.reciprocal %2508 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_2819 = torch.constant.float 1.000000e+00 - %2510 = torch.aten.mul.Scalar %2509, %float1.000000e00_2819 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_2820 = torch.constant.int 1 - %2511 = torch.aten.unsqueeze %2502, %int1_2820 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_2821 = torch.constant.int 0 - %2512 = torch.aten.unsqueeze %2510, %int0_2821 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %2513 = torch.aten.mul.Tensor %2511, %2512 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_2822 = torch.constant.int 1 - %2514 = torch.aten.size.int %2481, %int1_2822 : !torch.vtensor<[4,?,4096],f16>, 
!torch.int -> !torch.int - %int0_2823 = torch.constant.int 0 - %2515 = torch.aten.add.int %int0_2823, %2514 : !torch.int, !torch.int -> !torch.int - %int0_2824 = torch.constant.int 0 - %int0_2825 = torch.constant.int 0 - %int1_2826 = torch.constant.int 1 - %2516 = torch.aten.slice.Tensor %2513, %int0_2824, %int0_2825, %2515, %int1_2826 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2516, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2827 = torch.constant.int 1 - %int0_2828 = torch.constant.int 0 - %int9223372036854775807_2829 = torch.constant.int 9223372036854775807 - %int1_2830 = torch.constant.int 1 - %2517 = torch.aten.slice.Tensor %2516, %int1_2827, %int0_2828, %int9223372036854775807_2829, %int1_2830 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2517, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2831 = torch.constant.int 1 - %int0_2832 = torch.constant.int 0 - %int9223372036854775807_2833 = torch.constant.int 9223372036854775807 - %int1_2834 = torch.constant.int 1 - %2518 = torch.aten.slice.Tensor %2517, %int1_2831, %int0_2832, %int9223372036854775807_2833, %int1_2834 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2518, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_2835 = torch.constant.int 0 - %2519 = torch.aten.unsqueeze %2518, %int0_2835 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2519, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_2836 = torch.constant.int 1 - %int0_2837 = torch.constant.int 0 - %int9223372036854775807_2838 = torch.constant.int 9223372036854775807 - %int1_2839 = torch.constant.int 1 - %2520 = torch.aten.slice.Tensor %2519, %int1_2836, %int0_2837, %int9223372036854775807_2838, %int1_2839 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2520, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_2840 = torch.constant.int 2 - %int0_2841 = torch.constant.int 0 - %int9223372036854775807_2842 = torch.constant.int 9223372036854775807 - %int1_2843 = torch.constant.int 1 - %2521 = torch.aten.slice.Tensor %2520, %int2_2840, %int0_2841, %int9223372036854775807_2842, %int1_2843 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2521, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_2844 = torch.constant.int 4 + %int4096_2787 = torch.constant.int 4096 + %2475 = torch.prim.ListConstruct %int4_2786, %298, %int4096_2787 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2476 = torch.aten.view %2474, %2475 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2476, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_2788 = torch.constant.int 1 + %2477 = torch.aten.add.Tensor %2443, %2476, %int1_2788 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2477, [%294], 
affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_2789 = torch.constant.int 6 + %2478 = torch.prims.convert_element_type %2477, %int6_2789 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2478, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_2790 = torch.constant.int 2 + %2479 = torch.aten.pow.Tensor_Scalar %2478, %int2_2790 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2479, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_2791 = torch.constant.int -1 + %2480 = torch.prim.ListConstruct %int-1_2791 : (!torch.int) -> !torch.list + %true_2792 = torch.constant.bool true + %none_2793 = torch.constant.none + %2481 = torch.aten.mean.dim %2479, %2480, %true_2792, %none_2793 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2481, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_2794 = torch.constant.float 9.9999997473787516E-6 + %int1_2795 = torch.constant.int 1 + %2482 = torch.aten.add.Scalar %2481, %float9.999990e-06_2794, %int1_2795 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2482, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2483 = torch.aten.rsqrt %2482 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2483, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2484 = torch.aten.mul.Tensor %2478, %2483 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2484, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_2796 = torch.constant.int 5 + %2485 = torch.prims.convert_element_type %2484, %int5_2796 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2485, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %2486 = torch.aten.mul.Tensor %74, %2485 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2486, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_2797 = torch.constant.int 5 + %2487 = torch.prims.convert_element_type %2486, %int5_2797 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2487, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_2798 = torch.constant.int -2 + %int-1_2799 = torch.constant.int -1 + %2488 = torch.aten.transpose.int %75, %int-2_2798, %int-1_2799 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2800 = torch.constant.int 5 + %2489 = torch.prims.convert_element_type %2488, %int5_2800 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_2801 = torch.constant.int 4096 + %2490 = torch.prim.ListConstruct %342, %int4096_2801 : (!torch.int, !torch.int) -> !torch.list + %2491 = torch.aten.view %2487, %2490 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2491, [%294], affine_map<()[s0] -> (s0 * 128, 
4096)> : !torch.vtensor<[?,4096],f16> + %2492 = torch.aten.mm %2491, %2489 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2492, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_2802 = torch.constant.int 4 + %int4096_2803 = torch.constant.int 4096 + %2493 = torch.prim.ListConstruct %int4_2802, %298, %int4096_2803 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2494 = torch.aten.view %2492, %2493 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2494, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_2804 = torch.constant.int -2 + %int-1_2805 = torch.constant.int -1 + %2495 = torch.aten.transpose.int %76, %int-2_2804, %int-1_2805 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2806 = torch.constant.int 5 + %2496 = torch.prims.convert_element_type %2495, %int5_2806 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_2807 = torch.constant.int 4096 + %2497 = torch.prim.ListConstruct %342, %int4096_2807 : (!torch.int, !torch.int) -> !torch.list + %2498 = torch.aten.view %2487, %2497 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2498, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2499 = torch.aten.mm %2498, %2496 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %2499, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_2808 = torch.constant.int 4 + %int1024_2809 = torch.constant.int 1024 + %2500 = torch.prim.ListConstruct %int4_2808, %298, %int1024_2809 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2501 = torch.aten.view %2499, %2500 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %2501, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_2810 = torch.constant.int -2 + %int-1_2811 = torch.constant.int -1 + %2502 = torch.aten.transpose.int %77, %int-2_2810, %int-1_2811 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2812 = torch.constant.int 5 + %2503 = torch.prims.convert_element_type %2502, %int5_2812 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_2813 = torch.constant.int 4096 + %2504 = torch.prim.ListConstruct %342, %int4096_2813 : (!torch.int, !torch.int) -> !torch.list + %2505 = torch.aten.view %2487, %2504 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2505, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2506 = torch.aten.mm %2505, %2503 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %2506, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_2814 = torch.constant.int 4 + %int1024_2815 = torch.constant.int 1024 + %2507 = torch.prim.ListConstruct %int4_2814, %298, %int1024_2815 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2508 = torch.aten.view %2506, %2507 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %2508, 
[%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_2816 = torch.constant.int 4 + %int32_2817 = torch.constant.int 32 + %int128_2818 = torch.constant.int 128 + %2509 = torch.prim.ListConstruct %int4_2816, %298, %int32_2817, %int128_2818 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2510 = torch.aten.view %2494, %2509 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2510, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_2819 = torch.constant.int 4 + %int8_2820 = torch.constant.int 8 + %int128_2821 = torch.constant.int 128 + %2511 = torch.prim.ListConstruct %int4_2819, %298, %int8_2820, %int128_2821 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2512 = torch.aten.view %2501, %2511 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2512, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_2822 = torch.constant.int 4 + %int8_2823 = torch.constant.int 8 + %int128_2824 = torch.constant.int 128 + %2513 = torch.prim.ListConstruct %int4_2822, %298, %int8_2823, %int128_2824 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2514 = torch.aten.view %2508, %2513 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2514, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_2825 = torch.constant.int 131072 + %none_2826 = torch.constant.none + %none_2827 = torch.constant.none + %cpu_2828 = torch.constant.device "cpu" + %false_2829 = torch.constant.bool false + %2515 = torch.aten.arange %int131072_2825, %none_2826, %none_2827, %cpu_2828, %false_2829 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_2830 = torch.constant.int 0 + %int128_2831 = torch.constant.int 128 + %int2_2832 = torch.constant.int 2 + %int4_2833 = torch.constant.int 4 + %none_2834 = torch.constant.none + %cpu_2835 = torch.constant.device "cpu" + %false_2836 = torch.constant.bool false + %2516 = torch.aten.arange.start_step %int0_2830, %int128_2831, %int2_2832, %int4_2833, %none_2834, %cpu_2835, %false_2836 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_2837 = torch.constant.int 6 + %2517 = torch.prims.convert_element_type %2516, %int6_2837 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_2838 = torch.constant.int 128 + %2518 = torch.aten.div.Scalar %2517, %int128_2838 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_2839 = torch.constant.float 5.000000e+05 + %2519 = torch.aten.pow.Scalar %float5.000000e05_2839, %2518 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2520 = torch.aten.reciprocal %2519 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_2840 = torch.constant.float 1.000000e+00 + %2521 = torch.aten.mul.Scalar %2520, %float1.000000e00_2840 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %2522 = torch.aten.reciprocal %2521 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_2841 = torch.constant.float 6.2831853071795862 + %2523 = torch.aten.mul.Scalar %2522, %float6.283190e00_2841 : !torch.vtensor<[64],f32>, !torch.float -> 
!torch.vtensor<[64],f32> + %float8.192000e03_2842 = torch.constant.float 8.192000e+03 + %2524 = torch.aten.gt.Scalar %2523, %float8.192000e03_2842 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_2843 = torch.constant.int 8 + %2525 = torch.aten.div.Scalar %2521, %int8_2843 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2526 = torch.aten.where.self %2524, %2525, %2521 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2527 = torch.aten.reciprocal %2523 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_2844 = torch.constant.int 8192 + %2528 = torch.aten.mul.Scalar %2527, %int8192_2844 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_2845 = torch.constant.int 1 %int1_2846 = torch.constant.int 1 - %2522 = torch.prim.ListConstruct %int4_2844, %int1_2845, %int1_2846 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2523 = torch.aten.repeat %2521, %2522 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %2523, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_2847 = torch.constant.int 6 - %2524 = torch.prims.convert_element_type %2497, %int6_2847 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %2524, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %2525 = torch_c.to_builtin_tensor %2524 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %2526 = torch_c.to_builtin_tensor %2523 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %2527 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%2525, %2526) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %2528 = torch_c.from_builtin_tensor %2527 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %2528, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_2848 = torch.constant.int 5 - %2529 = torch.prims.convert_element_type %2528, %int5_2848 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2529, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_2849 = torch.constant.int 131072 - %none_2850 = torch.constant.none - %none_2851 = torch.constant.none - %cpu_2852 = torch.constant.device "cpu" - %false_2853 = torch.constant.bool false - %2530 = torch.aten.arange %int131072_2849, %none_2850, %none_2851, %cpu_2852, %false_2853 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_2854 = torch.constant.int 0 - %int128_2855 = torch.constant.int 128 - %none_2856 = torch.constant.none - %none_2857 = torch.constant.none - %cpu_2858 = torch.constant.device "cpu" - %false_2859 = torch.constant.bool false - %2531 = torch.aten.arange.start %int0_2854, %int128_2855, %none_2856, %none_2857, %cpu_2858, %false_2859 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_2860 = torch.constant.int 2 - %2532 = torch.aten.floor_divide.Scalar %2531, %int2_2860 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_2861 = torch.constant.int 6 - %2533 = torch.prims.convert_element_type %2532, %int6_2861 : !torch.vtensor<[128],si64>, !torch.int -> 
!torch.vtensor<[128],f32> - %int128_2862 = torch.constant.int 128 - %2534 = torch.aten.div.Scalar %2533, %int128_2862 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_2863 = torch.constant.float 2.000000e+00 - %2535 = torch.aten.mul.Scalar %2534, %float2.000000e00_2863 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_2864 = torch.constant.float 5.000000e+05 - %2536 = torch.aten.pow.Scalar %float5.000000e05_2864, %2535 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %2537 = torch.aten.reciprocal %2536 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_2865 = torch.constant.float 1.000000e+00 - %2538 = torch.aten.mul.Scalar %2537, %float1.000000e00_2865 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_2866 = torch.constant.int 1 - %2539 = torch.aten.unsqueeze %2530, %int1_2866 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_2867 = torch.constant.int 0 - %2540 = torch.aten.unsqueeze %2538, %int0_2867 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %2541 = torch.aten.mul.Tensor %2539, %2540 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %2529 = torch.aten.sub.Scalar %2528, %int1_2845, %int1_2846 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_2847 = torch.constant.int 3 + %2530 = torch.aten.div.Scalar %2529, %int3_2847 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_2848 = torch.constant.int 1 + %int1_2849 = torch.constant.int 1 + %2531 = torch.aten.rsub.Scalar %2530, %int1_2848, %int1_2849 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %2532 = torch.aten.mul.Tensor %2531, %2526 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_2850 = torch.constant.int 8 + %2533 = torch.aten.div.Scalar %2532, %int8_2850 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2534 = torch.aten.mul.Tensor %2530, %2526 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_2851 = torch.constant.int 1 + %2535 = torch.aten.add.Tensor %2533, %2534, %int1_2851 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_2852 = torch.constant.float 2.048000e+03 + %2536 = torch.aten.lt.Scalar %2523, %float2.048000e03_2852 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2537 = torch.aten.bitwise_not %2536 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_2853 = torch.constant.float 8.192000e+03 + %2538 = torch.aten.gt.Scalar %2523, %float8.192000e03_2853 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2539 = torch.aten.bitwise_not %2538 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2540 = torch.aten.mul.Tensor %2537, %2539 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2541 = torch.aten.where.self %2540, %2535, %2526 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2542 = torch.prim.ListConstruct %2541, %2541 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_2854 = torch.constant.int -1 + %2543 = torch.aten.cat %2542, %int-1_2854 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_2855 = 
torch.constant.int 6 + %2544 = torch.prims.convert_element_type %2543, %int6_2855 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_2856 = torch.constant.int 1 + %2545 = torch.aten.unsqueeze %2515, %int1_2856 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_2857 = torch.constant.int 6 + %2546 = torch.prims.convert_element_type %2545, %int6_2857 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_2858 = torch.constant.int 0 + %2547 = torch.aten.unsqueeze %2544, %int0_2858 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_2859 = torch.constant.int 6 + %2548 = torch.prims.convert_element_type %2547, %int6_2859 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %2549 = torch.aten.mul.Tensor %2546, %2548 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %2550 = torch.aten.cos %2549 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_2860 = torch.constant.int 5 + %2551 = torch.prims.convert_element_type %2550, %int5_2860 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %2552 = torch.aten.sin %2549 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_2861 = torch.constant.int 5 + %2553 = torch.prims.convert_element_type %2552, %int5_2861 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_2862 = torch.constant.int 0 + %int0_2863 = torch.constant.int 0 + %int1_2864 = torch.constant.int 1 + %2554 = torch.aten.slice.Tensor %2551, %int0_2862, %int0_2863, %298, %int1_2864 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2554, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_2865 = torch.constant.int 1 + %int0_2866 = torch.constant.int 0 + %int9223372036854775807_2867 = torch.constant.int 9223372036854775807 %int1_2868 = torch.constant.int 1 - %2542 = torch.aten.size.int %2488, %int1_2868 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int + %2555 = torch.aten.slice.Tensor %2554, %int1_2865, %int0_2866, %int9223372036854775807_2867, %int1_2868 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2555, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int0_2869 = torch.constant.int 0 - %2543 = torch.aten.add.int %int0_2869, %2542 : !torch.int, !torch.int -> !torch.int %int0_2870 = torch.constant.int 0 - %int0_2871 = torch.constant.int 0 + %int1_2871 = torch.constant.int 1 + %2556 = torch.aten.slice.Tensor %2553, %int0_2869, %int0_2870, %298, %int1_2871 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2556, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_2872 = torch.constant.int 1 - %2544 = torch.aten.slice.Tensor %2541, %int0_2870, %int0_2871, %2543, %int1_2872 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2544, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_2873 = torch.constant.int 1 - %int0_2874 = torch.constant.int 0 - %int9223372036854775807_2875 = torch.constant.int 
9223372036854775807 - %int1_2876 = torch.constant.int 1 - %2545 = torch.aten.slice.Tensor %2544, %int1_2873, %int0_2874, %int9223372036854775807_2875, %int1_2876 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2545, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %int0_2873 = torch.constant.int 0 + %int9223372036854775807_2874 = torch.constant.int 9223372036854775807 + %int1_2875 = torch.constant.int 1 + %2557 = torch.aten.slice.Tensor %2556, %int1_2872, %int0_2873, %int9223372036854775807_2874, %int1_2875 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2557, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_2876 = torch.constant.int 0 + %2558 = torch.aten.unsqueeze %2555, %int0_2876 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2558, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int1_2877 = torch.constant.int 1 %int0_2878 = torch.constant.int 0 %int9223372036854775807_2879 = torch.constant.int 9223372036854775807 %int1_2880 = torch.constant.int 1 - %2546 = torch.aten.slice.Tensor %2545, %int1_2877, %int0_2878, %int9223372036854775807_2879, %int1_2880 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2546, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_2881 = torch.constant.int 0 - %2547 = torch.aten.unsqueeze %2546, %int0_2881 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2547, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_2882 = torch.constant.int 1 + %2559 = torch.aten.slice.Tensor %2558, %int1_2877, %int0_2878, %int9223372036854775807_2879, %int1_2880 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2559, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_2881 = torch.constant.int 2 + %2560 = torch.aten.unsqueeze %2559, %int2_2881 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2560, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_2882 = torch.constant.int 3 %int0_2883 = torch.constant.int 0 %int9223372036854775807_2884 = torch.constant.int 9223372036854775807 %int1_2885 = torch.constant.int 1 - %2548 = torch.aten.slice.Tensor %2547, %int1_2882, %int0_2883, %int9223372036854775807_2884, %int1_2885 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2548, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_2886 = torch.constant.int 2 - %int0_2887 = torch.constant.int 0 - %int9223372036854775807_2888 = torch.constant.int 9223372036854775807 + %2561 = torch.aten.slice.Tensor %2560, %int3_2882, %int0_2883, %int9223372036854775807_2884, %int1_2885 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2561, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + 
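The + ops around here replace the old f32 rotary path (the - lines built a [131072,128] f32 angle table per call and dispatched to util.call @sharktank_rotary_embedding_4_D_32_128_f32) with an inline f16 rotate-half implementation: llama3-style frequency scaling (%2516..%2541), cos/sin tables (%2549..%2553), and, a little further down, the slice/neg/cat application (%2570..%2577). A PyTorch-level sketch of what the new ops compute, with the constants (base 5.0e5, scaling factor 8, smooth band between wavelengths 2048 and 8192) read off the surrounding torch.constant ops; the function names are mine, not the exporter's.

import math
import torch

def llama3_inv_freq(dim=128, base=5.0e5, factor=8.0,
                    low_factor=1.0, high_factor=4.0, old_ctx=8192.0):
    # %2516..%2526: base frequencies, with long wavelengths slowed by `factor`
    inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    wavelen = 2 * math.pi / inv
    scaled = torch.where(wavelen > old_ctx / low_factor, inv / factor, inv)
    # %2528..%2541: linear interpolation over the 2048..8192 wavelength band
    smooth = (old_ctx / wavelen - low_factor) / (high_factor - low_factor)
    smoothed = (1 - smooth) * scaled / factor + smooth * scaled
    medium = (wavelen >= old_ctx / high_factor) & (wavelen <= old_ctx / low_factor)
    return torch.where(medium, smoothed, scaled)

def rope_tables(max_pos=131072, dim=128):
    # %2542..%2553: duplicate the 64 frequencies to 128 lanes, then take
    # cos/sin of position * frequency and downcast the tables to f16
    inv = llama3_inv_freq(dim)
    freqs = torch.cat([inv, inv])
    angles = torch.arange(max_pos, dtype=torch.float32)[:, None] * freqs[None, :]
    return angles.cos().half(), angles.sin().half()

def apply_rope(x, cos, sin):
    # x: [B, T, H, 128]; cos/sin: the first T rows of the tables (%2554, %2556)
    # %2570..%2577: rotate-half form, x*cos + cat(-x2, x1)*sin
    x1, x2 = x[..., :64], x[..., 64:]
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos[None, :, None, :] + rotated * sin[None, :, None, :]

One layout difference: the IR materializes the broadcast explicitly (unsqueeze to [1,T,1,128], then torch.aten.repeat to batch 4, as in %2563 and %2569), where the sketch relies on PyTorch broadcasting; the arithmetic is identical.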
%int4_2886 = torch.constant.int 4 + %int1_2887 = torch.constant.int 1 + %int1_2888 = torch.constant.int 1 %int1_2889 = torch.constant.int 1 - %2549 = torch.aten.slice.Tensor %2548, %int2_2886, %int0_2887, %int9223372036854775807_2888, %int1_2889 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2549, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_2890 = torch.constant.int 4 + %2562 = torch.prim.ListConstruct %int4_2886, %int1_2887, %int1_2888, %int1_2889 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2563 = torch.aten.repeat %2561, %2562 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2563, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_2890 = torch.constant.int 0 + %2564 = torch.aten.unsqueeze %2557, %int0_2890 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2564, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int1_2891 = torch.constant.int 1 - %int1_2892 = torch.constant.int 1 - %2550 = torch.prim.ListConstruct %int4_2890, %int1_2891, %int1_2892 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2551 = torch.aten.repeat %2549, %2550 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %2551, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_2893 = torch.constant.int 6 - %2552 = torch.prims.convert_element_type %2499, %int6_2893 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %2552, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %2553 = torch_c.to_builtin_tensor %2552 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %2554 = torch_c.to_builtin_tensor %2551 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %2555 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%2553, %2554) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %2556 = torch_c.from_builtin_tensor %2555 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %2556, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_2894 = torch.constant.int 5 - %2557 = torch.prims.convert_element_type %2556, %int5_2894 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2557, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_2895 = torch.constant.int 64 - %2558 = torch.aten.mul.Scalar %arg2, %int64_2895 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2558, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int22 = torch.constant.int 22 - %int1_2896 = torch.constant.int 1 - %2559 = torch.aten.add.Scalar %2558, %int22, %int1_2896 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2559, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_2897 = torch.constant.int 4 - %int32_2898 = torch.constant.int 32 - %int8_2899 = torch.constant.int 8 - %int128_2900 = torch.constant.int 128 - %2560 = torch.prim.ListConstruct %int4_2897, %398, 
%int32_2898, %int8_2899, %int128_2900 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2561 = torch.aten.view %2557, %2560 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2561, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_2901 = torch.constant.int 4 - %2562 = torch.aten.mul.int %int4_2901, %398 : !torch.int, !torch.int -> !torch.int - %int32_2902 = torch.constant.int 32 - %int8_2903 = torch.constant.int 8 - %int128_2904 = torch.constant.int 128 - %2563 = torch.prim.ListConstruct %2562, %int32_2902, %int8_2903, %int128_2904 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2564 = torch.aten.view %2561, %2563 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2564, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_2905 = torch.constant.int 4 - %2565 = torch.aten.mul.int %int4_2905, %398 : !torch.int, !torch.int -> !torch.int - %2566 = torch.prim.ListConstruct %2565 : (!torch.int) -> !torch.list - %2567 = torch.aten.view %2559, %2566 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2567, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_2906 = torch.constant.int 32 - %int2_2907 = torch.constant.int 2 - %int32_2908 = torch.constant.int 32 - %int8_2909 = torch.constant.int 8 - %int128_2910 = torch.constant.int 128 - %2568 = torch.prim.ListConstruct %389, %int32_2906, %int2_2907, %int32_2908, %int8_2909, %int128_2910 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2569 = torch.aten.view %2401, %2568 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2569, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2911 = torch.constant.int 32 - %2570 = torch.aten.mul.int %389, %int32_2911 : !torch.int, !torch.int -> !torch.int - %int2_2912 = torch.constant.int 2 - %2571 = torch.aten.mul.int %2570, %int2_2912 : !torch.int, !torch.int -> !torch.int - %int32_2913 = torch.constant.int 32 - %int8_2914 = torch.constant.int 8 - %int128_2915 = torch.constant.int 128 - %2572 = torch.prim.ListConstruct %2571, %int32_2913, %int8_2914, %int128_2915 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2573 = torch.aten.view %2569, %2572 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2573, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %2574 = torch.prim.ListConstruct %2567 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_2916 = torch.constant.bool false - %2575 = torch.aten.index_put %2573, %2574, %2564, %false_2916 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2575, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_2917 = torch.constant.int 32 - %int2_2918 = torch.constant.int 2 - %int32_2919 = torch.constant.int 32 - %int8_2920 = torch.constant.int 8 - %int128_2921 = torch.constant.int 128 - %2576 = torch.prim.ListConstruct %389, %int32_2917, %int2_2918, %int32_2919, %int8_2920, %int128_2921 : (!torch.int, !torch.int, 
!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2577 = torch.aten.view %2575, %2576 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2577, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2922 = torch.constant.int 2097152 - %2578 = torch.prim.ListConstruct %389, %int2097152_2922 : (!torch.int, !torch.int) -> !torch.list - %2579 = torch.aten.view %2577, %2578 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2579, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_2923 = torch.constant.int 32 - %int2_2924 = torch.constant.int 2 - %int32_2925 = torch.constant.int 32 - %int8_2926 = torch.constant.int 8 + %int0_2892 = torch.constant.int 0 + %int9223372036854775807_2893 = torch.constant.int 9223372036854775807 + %int1_2894 = torch.constant.int 1 + %2565 = torch.aten.slice.Tensor %2564, %int1_2891, %int0_2892, %int9223372036854775807_2893, %int1_2894 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2565, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_2895 = torch.constant.int 2 + %2566 = torch.aten.unsqueeze %2565, %int2_2895 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2566, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_2896 = torch.constant.int 3 + %int0_2897 = torch.constant.int 0 + %int9223372036854775807_2898 = torch.constant.int 9223372036854775807 + %int1_2899 = torch.constant.int 1 + %2567 = torch.aten.slice.Tensor %2566, %int3_2896, %int0_2897, %int9223372036854775807_2898, %int1_2899 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2567, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_2900 = torch.constant.int 4 + %int1_2901 = torch.constant.int 1 + %int1_2902 = torch.constant.int 1 + %int1_2903 = torch.constant.int 1 + %2568 = torch.prim.ListConstruct %int4_2900, %int1_2901, %int1_2902, %int1_2903 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2569 = torch.aten.repeat %2567, %2568 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2569, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %2570 = torch.aten.mul.Tensor %2510, %2563 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2570, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_2904 = torch.constant.int 3 + %int0_2905 = torch.constant.int 0 + %int64_2906 = torch.constant.int 64 + %int1_2907 = torch.constant.int 1 + %2571 = torch.aten.slice.Tensor %2510, %int3_2904, %int0_2905, %int64_2906, %int1_2907 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %2571, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_2908 = torch.constant.int 3 + %int64_2909 = torch.constant.int 64 + %int9223372036854775807_2910 = torch.constant.int 
9223372036854775807 + %int1_2911 = torch.constant.int 1 + %2572 = torch.aten.slice.Tensor %2510, %int3_2908, %int64_2909, %int9223372036854775807_2910, %int1_2911 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %2572, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %2573 = torch.aten.neg %2572 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %2573, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %2574 = torch.prim.ListConstruct %2573, %2571 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_2912 = torch.constant.int -1 + %2575 = torch.aten.cat %2574, %int-1_2912 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2575, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %2576 = torch.aten.mul.Tensor %2575, %2569 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2576, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_2913 = torch.constant.int 1 + %2577 = torch.aten.add.Tensor %2570, %2576, %int1_2913 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2577, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_2914 = torch.constant.int 131072 + %none_2915 = torch.constant.none + %none_2916 = torch.constant.none + %cpu_2917 = torch.constant.device "cpu" + %false_2918 = torch.constant.bool false + %2578 = torch.aten.arange %int131072_2914, %none_2915, %none_2916, %cpu_2917, %false_2918 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_2919 = torch.constant.int 0 + %int128_2920 = torch.constant.int 128 + %int2_2921 = torch.constant.int 2 + %int4_2922 = torch.constant.int 4 + %none_2923 = torch.constant.none + %cpu_2924 = torch.constant.device "cpu" + %false_2925 = torch.constant.bool false + %2579 = torch.aten.arange.start_step %int0_2919, %int128_2920, %int2_2921, %int4_2922, %none_2923, %cpu_2924, %false_2925 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_2926 = torch.constant.int 6 + %2580 = torch.prims.convert_element_type %2579, %int6_2926 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> %int128_2927 = torch.constant.int 128 - %2580 = torch.prim.ListConstruct %389, %int32_2923, %int2_2924, %int32_2925, %int8_2926, %int128_2927 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2581 = torch.aten.view %2579, %2580 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2581, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2928 = torch.constant.int 32 - %int8_2929 = torch.constant.int 8 - %int128_2930 = torch.constant.int 128 - %2582 = torch.prim.ListConstruct %2571, %int32_2928, %int8_2929, %int128_2930 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2583 = torch.aten.view %2581, %2582 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> 
!torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2583, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_2931 = torch.constant.int 4 - %int32_2932 = torch.constant.int 32 - %int8_2933 = torch.constant.int 8 - %int128_2934 = torch.constant.int 128 - %2584 = torch.prim.ListConstruct %int4_2931, %398, %int32_2932, %int8_2933, %int128_2934 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2585 = torch.aten.view %2501, %2584 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2585, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_2935 = torch.constant.int 4 - %2586 = torch.aten.mul.int %int4_2935, %398 : !torch.int, !torch.int -> !torch.int - %int32_2936 = torch.constant.int 32 - %int8_2937 = torch.constant.int 8 - %int128_2938 = torch.constant.int 128 - %2587 = torch.prim.ListConstruct %2586, %int32_2936, %int8_2937, %int128_2938 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2588 = torch.aten.view %2585, %2587 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2588, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_2939 = torch.constant.int 1 + %2581 = torch.aten.div.Scalar %2580, %int128_2927 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_2928 = torch.constant.float 5.000000e+05 + %2582 = torch.aten.pow.Scalar %float5.000000e05_2928, %2581 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2583 = torch.aten.reciprocal %2582 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_2929 = torch.constant.float 1.000000e+00 + %2584 = torch.aten.mul.Scalar %2583, %float1.000000e00_2929 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %2585 = torch.aten.reciprocal %2584 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_2930 = torch.constant.float 6.2831853071795862 + %2586 = torch.aten.mul.Scalar %2585, %float6.283190e00_2930 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_2931 = torch.constant.float 8.192000e+03 + %2587 = torch.aten.gt.Scalar %2586, %float8.192000e03_2931 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_2932 = torch.constant.int 8 + %2588 = torch.aten.div.Scalar %2584, %int8_2932 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2589 = torch.aten.where.self %2587, %2588, %2584 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2590 = torch.aten.reciprocal %2586 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_2933 = torch.constant.int 8192 + %2591 = torch.aten.mul.Scalar %2590, %int8192_2933 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_2934 = torch.constant.int 1 + %int1_2935 = torch.constant.int 1 + %2592 = torch.aten.sub.Scalar %2591, %int1_2934, %int1_2935 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_2936 = torch.constant.int 3 + %2593 = torch.aten.div.Scalar %2592, %int3_2936 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_2937 = torch.constant.int 1 + %int1_2938 = torch.constant.int 1 + %2594 = torch.aten.rsub.Scalar %2593, %int1_2937, %int1_2938 : !torch.vtensor<[64],f32>, 
!torch.int, !torch.int -> !torch.vtensor<[64],f32> + %2595 = torch.aten.mul.Tensor %2594, %2589 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_2939 = torch.constant.int 8 + %2596 = torch.aten.div.Scalar %2595, %int8_2939 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2597 = torch.aten.mul.Tensor %2593, %2589 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> %int1_2940 = torch.constant.int 1 - %2589 = torch.aten.add.Scalar %2559, %int1_2939, %int1_2940 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2589, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_2941 = torch.constant.int 4 - %2590 = torch.aten.mul.int %int4_2941, %398 : !torch.int, !torch.int -> !torch.int - %2591 = torch.prim.ListConstruct %2590 : (!torch.int) -> !torch.list<int> - %2592 = torch.aten.view %2589, %2591 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2592, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %2593 = torch.prim.ListConstruct %2592 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>> - %false_2942 = torch.constant.bool false - %2594 = torch.aten.index_put %2583, %2593, %2588, %false_2942 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2594, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_2943 = torch.constant.int 32 - %int2_2944 = torch.constant.int 2 - %int32_2945 = torch.constant.int 32 - %int8_2946 = torch.constant.int 8 - %int128_2947 = torch.constant.int 128 - %2595 = torch.prim.ListConstruct %389, %int32_2943, %int2_2944, %int32_2945, %int8_2946, %int128_2947 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %2596 = torch.aten.view %2594, %2595 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2596, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2948 = torch.constant.int 2097152 - %2597 = torch.prim.ListConstruct %389, %int2097152_2948 : (!torch.int, !torch.int) -> !torch.list<int> - %2598 = torch.aten.view %2596, %2597 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2598, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_2949 = torch.constant.int -2 - %2599 = torch.aten.unsqueeze %2557, %int-2_2949 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2599, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_2950 = torch.constant.int 4 - %int8_2951 = torch.constant.int 8 - %int4_2952 = torch.constant.int 4 - %int128_2953 = torch.constant.int 128 - %2600 = torch.prim.ListConstruct %int4_2950, %2542, %int8_2951, %int4_2952, %int128_2953 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %false_2954 = torch.constant.bool false - %2601 = torch.aten.expand %2599, %2600, %false_2954 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2601, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %2598 =
torch.aten.add.Tensor %2596, %2597, %int1_2940 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_2941 = torch.constant.float 2.048000e+03 + %2599 = torch.aten.lt.Scalar %2586, %float2.048000e03_2941 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2600 = torch.aten.bitwise_not %2599 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_2942 = torch.constant.float 8.192000e+03 + %2601 = torch.aten.gt.Scalar %2586, %float8.192000e03_2942 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2602 = torch.aten.bitwise_not %2601 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2603 = torch.aten.mul.Tensor %2600, %2602 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2604 = torch.aten.where.self %2603, %2598, %2589 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2605 = torch.prim.ListConstruct %2604, %2604 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_2943 = torch.constant.int -1 + %2606 = torch.aten.cat %2605, %int-1_2943 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_2944 = torch.constant.int 6 + %2607 = torch.prims.convert_element_type %2606, %int6_2944 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_2945 = torch.constant.int 1 + %2608 = torch.aten.unsqueeze %2578, %int1_2945 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_2946 = torch.constant.int 6 + %2609 = torch.prims.convert_element_type %2608, %int6_2946 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_2947 = torch.constant.int 0 + %2610 = torch.aten.unsqueeze %2607, %int0_2947 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_2948 = torch.constant.int 6 + %2611 = torch.prims.convert_element_type %2610, %int6_2948 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %2612 = torch.aten.mul.Tensor %2609, %2611 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %2613 = torch.aten.cos %2612 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_2949 = torch.constant.int 5 + %2614 = torch.prims.convert_element_type %2613, %int5_2949 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %2615 = torch.aten.sin %2612 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_2950 = torch.constant.int 5 + %2616 = torch.prims.convert_element_type %2615, %int5_2950 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_2951 = torch.constant.int 0 + %int0_2952 = torch.constant.int 0 + %int1_2953 = torch.constant.int 1 + %2617 = torch.aten.slice.Tensor %2614, %int0_2951, %int0_2952, %298, %int1_2953 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2617, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_2954 = torch.constant.int 1 %int0_2955 = torch.constant.int 0 - %2602 = torch.aten.clone %2601, %int0_2955 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2602, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2956 = torch.constant.int 4 
- %int32_2957 = torch.constant.int 32 - %int128_2958 = torch.constant.int 128 - %2603 = torch.prim.ListConstruct %int4_2956, %2542, %int32_2957, %int128_2958 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2604 = torch.aten._unsafe_view %2602, %2603 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2604, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_2959 = torch.constant.int -2 - %2605 = torch.aten.unsqueeze %2501, %int-2_2959 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2605, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int9223372036854775807_2956 = torch.constant.int 9223372036854775807 + %int1_2957 = torch.constant.int 1 + %2618 = torch.aten.slice.Tensor %2617, %int1_2954, %int0_2955, %int9223372036854775807_2956, %int1_2957 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2618, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_2958 = torch.constant.int 0 + %int0_2959 = torch.constant.int 0 %int1_2960 = torch.constant.int 1 - %2606 = torch.aten.size.int %2495, %int1_2960 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_2961 = torch.constant.int 4 - %int8_2962 = torch.constant.int 8 - %int4_2963 = torch.constant.int 4 - %int128_2964 = torch.constant.int 128 - %2607 = torch.prim.ListConstruct %int4_2961, %2606, %int8_2962, %int4_2963, %int128_2964 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2965 = torch.constant.bool false - %2608 = torch.aten.expand %2605, %2607, %false_2965 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2608, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2966 = torch.constant.int 0 - %2609 = torch.aten.clone %2608, %int0_2966 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2609, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2967 = torch.constant.int 4 - %int32_2968 = torch.constant.int 32 - %int128_2969 = torch.constant.int 128 - %2610 = torch.prim.ListConstruct %int4_2967, %2606, %int32_2968, %int128_2969 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2611 = torch.aten._unsafe_view %2609, %2610 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2611, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_2970 = torch.constant.int 1 - %int2_2971 = torch.constant.int 2 - %2612 = torch.aten.transpose.int %2529, %int1_2970, %int2_2971 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2612, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_2972 = torch.constant.int 1 - %int2_2973 = torch.constant.int 2 - %2613 = torch.aten.transpose.int %2604, %int1_2972, %int2_2973 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2613, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : 
!torch.vtensor<[4,32,?,128],f16> + %2619 = torch.aten.slice.Tensor %2616, %int0_2958, %int0_2959, %298, %int1_2960 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2619, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_2961 = torch.constant.int 1 + %int0_2962 = torch.constant.int 0 + %int9223372036854775807_2963 = torch.constant.int 9223372036854775807 + %int1_2964 = torch.constant.int 1 + %2620 = torch.aten.slice.Tensor %2619, %int1_2961, %int0_2962, %int9223372036854775807_2963, %int1_2964 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2620, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_2965 = torch.constant.int 0 + %2621 = torch.aten.unsqueeze %2618, %int0_2965 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2621, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_2966 = torch.constant.int 1 + %int0_2967 = torch.constant.int 0 + %int9223372036854775807_2968 = torch.constant.int 9223372036854775807 + %int1_2969 = torch.constant.int 1 + %2622 = torch.aten.slice.Tensor %2621, %int1_2966, %int0_2967, %int9223372036854775807_2968, %int1_2969 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2622, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_2970 = torch.constant.int 2 + %2623 = torch.aten.unsqueeze %2622, %int2_2970 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2623, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_2971 = torch.constant.int 3 + %int0_2972 = torch.constant.int 0 + %int9223372036854775807_2973 = torch.constant.int 9223372036854775807 %int1_2974 = torch.constant.int 1 - %int2_2975 = torch.constant.int 2 - %2614 = torch.aten.transpose.int %2611, %int1_2974, %int2_2975 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2614, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_2976 = torch.constant.float 0.000000e+00 - %true_2977 = torch.constant.bool true - %none_2978 = torch.constant.none - %none_2979 = torch.constant.none - %2615:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2612, %2613, %2614, %float0.000000e00_2976, %true_2977, %none_2978, %none_2979) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %2615#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %2624 = torch.aten.slice.Tensor %2623, %int3_2971, %int0_2972, %int9223372036854775807_2973, %int1_2974 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2624, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_2975 = torch.constant.int 4 + %int1_2976 = torch.constant.int 1 + %int1_2977 = torch.constant.int 1 + 
%int1_2978 = torch.constant.int 1 + %2625 = torch.prim.ListConstruct %int4_2975, %int1_2976, %int1_2977, %int1_2978 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2626 = torch.aten.repeat %2624, %2625 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2626, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_2979 = torch.constant.int 0 + %2627 = torch.aten.unsqueeze %2620, %int0_2979 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2627, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int1_2980 = torch.constant.int 1 - %int2_2981 = torch.constant.int 2 - %2616 = torch.aten.transpose.int %2615#0, %int1_2980, %int2_2981 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2616, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_2982 = torch.constant.int 4 - %int4096_2983 = torch.constant.int 4096 - %2617 = torch.prim.ListConstruct %int4_2982, %2514, %int4096_2983 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2618 = torch.aten.view %2616, %2617 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2618, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_2984 = torch.constant.int -2 - %int-1_2985 = torch.constant.int -1 - %2619 = torch.aten.transpose.int %104, %int-2_2984, %int-1_2985 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2986 = torch.constant.int 4 - %2620 = torch.aten.mul.int %int4_2986, %2514 : !torch.int, !torch.int -> !torch.int - %int4096_2987 = torch.constant.int 4096 - %2621 = torch.prim.ListConstruct %2620, %int4096_2987 : (!torch.int, !torch.int) -> !torch.list - %2622 = torch.aten.view %2618, %2621 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2622, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2623 = torch.aten.mm %2622, %2619 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2623, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_2988 = torch.constant.int 4 - %int4096_2989 = torch.constant.int 4096 - %2624 = torch.prim.ListConstruct %int4_2988, %2514, %int4096_2989 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2625 = torch.aten.view %2623, %2624 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2625, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int0_2981 = torch.constant.int 0 + %int9223372036854775807_2982 = torch.constant.int 9223372036854775807 + %int1_2983 = torch.constant.int 1 + %2628 = torch.aten.slice.Tensor %2627, %int1_2980, %int0_2981, %int9223372036854775807_2982, %int1_2983 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2628, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_2984 = torch.constant.int 2 + %2629 = torch.aten.unsqueeze %2628, %int2_2984 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + 
torch.bind_symbolic_shape %2629, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_2985 = torch.constant.int 3 + %int0_2986 = torch.constant.int 0 + %int9223372036854775807_2987 = torch.constant.int 9223372036854775807 + %int1_2988 = torch.constant.int 1 + %2630 = torch.aten.slice.Tensor %2629, %int3_2985, %int0_2986, %int9223372036854775807_2987, %int1_2988 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2630, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_2989 = torch.constant.int 4 %int1_2990 = torch.constant.int 1 - %2626 = torch.aten.add.Tensor %2464, %2625, %int1_2990 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2626, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_2991 = torch.constant.int 6 - %2627 = torch.prims.convert_element_type %2626, %int6_2991 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2627, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_2992 = torch.constant.int 2 - %2628 = torch.aten.pow.Tensor_Scalar %2627, %int2_2992 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2628, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_2993 = torch.constant.int -1 - %2629 = torch.prim.ListConstruct %int-1_2993 : (!torch.int) -> !torch.list - %true_2994 = torch.constant.bool true - %none_2995 = torch.constant.none - %2630 = torch.aten.mean.dim %2628, %2629, %true_2994, %none_2995 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2630, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_2996 = torch.constant.float 9.9999997473787516E-6 - %int1_2997 = torch.constant.int 1 - %2631 = torch.aten.add.Scalar %2630, %float9.999990e-06_2996, %int1_2997 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2631, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2632 = torch.aten.rsqrt %2631 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2632, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2633 = torch.aten.mul.Tensor %2627, %2632 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2633, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2998 = torch.constant.int 5 - %2634 = torch.prims.convert_element_type %2633, %int5_2998 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2634, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %2635 = torch.aten.mul.Tensor %105, %2634 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2635, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_2999 = torch.constant.int 5 - %2636 = torch.prims.convert_element_type %2635, %int5_2999 : 
!torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2636, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3000 = torch.constant.int -2 + %int1_2991 = torch.constant.int 1 + %int1_2992 = torch.constant.int 1 + %2631 = torch.prim.ListConstruct %int4_2989, %int1_2990, %int1_2991, %int1_2992 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2632 = torch.aten.repeat %2630, %2631 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2632, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %2633 = torch.aten.mul.Tensor %2512, %2626 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2633, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_2993 = torch.constant.int 3 + %int0_2994 = torch.constant.int 0 + %int64_2995 = torch.constant.int 64 + %int1_2996 = torch.constant.int 1 + %2634 = torch.aten.slice.Tensor %2512, %int3_2993, %int0_2994, %int64_2995, %int1_2996 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %2634, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_2997 = torch.constant.int 3 + %int64_2998 = torch.constant.int 64 + %int9223372036854775807_2999 = torch.constant.int 9223372036854775807 + %int1_3000 = torch.constant.int 1 + %2635 = torch.aten.slice.Tensor %2512, %int3_2997, %int64_2998, %int9223372036854775807_2999, %int1_3000 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %2635, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %2636 = torch.aten.neg %2635 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %2636, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %2637 = torch.prim.ListConstruct %2636, %2634 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list %int-1_3001 = torch.constant.int -1 - %2637 = torch.aten.transpose.int %106, %int-2_3000, %int-1_3001 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3002 = torch.constant.int 4 - %2638 = torch.aten.mul.int %int4_3002, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3003 = torch.constant.int 4096 - %2639 = torch.prim.ListConstruct %2638, %int4096_3003 : (!torch.int, !torch.int) -> !torch.list - %2640 = torch.aten.view %2636, %2639 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2640, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2641 = torch.aten.mm %2640, %2637 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2641, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_3004 = torch.constant.int 4 - %int14336_3005 = torch.constant.int 14336 - %2642 = torch.prim.ListConstruct %int4_3004, %306, %int14336_3005 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2643 = torch.aten.view %2641, %2642 : !torch.vtensor<[?,14336],f16>, !torch.list -> 
!torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2643, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %2644 = torch.aten.silu %2643 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2644, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_3006 = torch.constant.int -2 - %int-1_3007 = torch.constant.int -1 - %2645 = torch.aten.transpose.int %107, %int-2_3006, %int-1_3007 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3008 = torch.constant.int 4 - %2646 = torch.aten.mul.int %int4_3008, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3009 = torch.constant.int 4096 - %2647 = torch.prim.ListConstruct %2646, %int4096_3009 : (!torch.int, !torch.int) -> !torch.list - %2648 = torch.aten.view %2636, %2647 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2648, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2649 = torch.aten.mm %2648, %2645 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2649, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_3010 = torch.constant.int 4 - %int14336_3011 = torch.constant.int 14336 - %2650 = torch.prim.ListConstruct %int4_3010, %306, %int14336_3011 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2651 = torch.aten.view %2649, %2650 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2651, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %2652 = torch.aten.mul.Tensor %2644, %2651 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2652, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_3012 = torch.constant.int -2 - %int-1_3013 = torch.constant.int -1 - %2653 = torch.aten.transpose.int %108, %int-2_3012, %int-1_3013 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_3014 = torch.constant.int 1 - %2654 = torch.aten.size.int %2643, %int1_3014 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_3015 = torch.constant.int 4 - %2655 = torch.aten.mul.int %int4_3015, %2654 : !torch.int, !torch.int -> !torch.int - %int14336_3016 = torch.constant.int 14336 - %2656 = torch.prim.ListConstruct %2655, %int14336_3016 : (!torch.int, !torch.int) -> !torch.list - %2657 = torch.aten.view %2652, %2656 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2657, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %2658 = torch.aten.mm %2657, %2653 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2658, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_3017 = torch.constant.int 4 - %int4096_3018 = torch.constant.int 4096 - %2659 = torch.prim.ListConstruct %int4_3017, %2654, %int4096_3018 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2660 = torch.aten.view %2658, %2659 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2660, [%292], 
affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_3019 = torch.constant.int 1 - %2661 = torch.aten.add.Tensor %2626, %2660, %int1_3019 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2661, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_3020 = torch.constant.int 6 - %2662 = torch.prims.convert_element_type %2661, %int6_3020 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2662, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_3021 = torch.constant.int 2 - %2663 = torch.aten.pow.Tensor_Scalar %2662, %int2_3021 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2663, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_3022 = torch.constant.int -1 - %2664 = torch.prim.ListConstruct %int-1_3022 : (!torch.int) -> !torch.list - %true_3023 = torch.constant.bool true - %none_3024 = torch.constant.none - %2665 = torch.aten.mean.dim %2663, %2664, %true_3023, %none_3024 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2665, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_3025 = torch.constant.float 9.9999997473787516E-6 - %int1_3026 = torch.constant.int 1 - %2666 = torch.aten.add.Scalar %2665, %float9.999990e-06_3025, %int1_3026 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2666, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2667 = torch.aten.rsqrt %2666 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2667, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2668 = torch.aten.mul.Tensor %2662, %2667 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2668, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3027 = torch.constant.int 5 - %2669 = torch.prims.convert_element_type %2668, %int5_3027 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2669, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %2670 = torch.aten.mul.Tensor %109, %2669 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2670, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3028 = torch.constant.int 5 - %2671 = torch.prims.convert_element_type %2670, %int5_3028 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2671, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3029 = torch.constant.int -2 - %int-1_3030 = torch.constant.int -1 - %2672 = torch.aten.transpose.int %110, %int-2_3029, %int-1_3030 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3031 = torch.constant.int 4 - %2673 = torch.aten.mul.int %int4_3031, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3032 = torch.constant.int 4096 - %2674 = 
torch.prim.ListConstruct %2673, %int4096_3032 : (!torch.int, !torch.int) -> !torch.list - %2675 = torch.aten.view %2671, %2674 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2675, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2676 = torch.aten.mm %2675, %2672 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2676, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_3033 = torch.constant.int 4 - %int4096_3034 = torch.constant.int 4096 - %2677 = torch.prim.ListConstruct %int4_3033, %306, %int4096_3034 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2678 = torch.aten.view %2676, %2677 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2678, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3035 = torch.constant.int -2 - %int-1_3036 = torch.constant.int -1 - %2679 = torch.aten.transpose.int %111, %int-2_3035, %int-1_3036 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3037 = torch.constant.int 4 - %2680 = torch.aten.mul.int %int4_3037, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3038 = torch.constant.int 4096 - %2681 = torch.prim.ListConstruct %2680, %int4096_3038 : (!torch.int, !torch.int) -> !torch.list - %2682 = torch.aten.view %2671, %2681 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2682, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2683 = torch.aten.mm %2682, %2679 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %2683, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_3039 = torch.constant.int 4 - %int1024_3040 = torch.constant.int 1024 - %2684 = torch.prim.ListConstruct %int4_3039, %306, %int1024_3040 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2685 = torch.aten.view %2683, %2684 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %2685, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_3041 = torch.constant.int -2 - %int-1_3042 = torch.constant.int -1 - %2686 = torch.aten.transpose.int %112, %int-2_3041, %int-1_3042 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3043 = torch.constant.int 4 - %2687 = torch.aten.mul.int %int4_3043, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3044 = torch.constant.int 4096 - %2688 = torch.prim.ListConstruct %2687, %int4096_3044 : (!torch.int, !torch.int) -> !torch.list - %2689 = torch.aten.view %2671, %2688 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2689, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2690 = torch.aten.mm %2689, %2686 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %2690, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_3045 = torch.constant.int 4 - %int1024_3046 = torch.constant.int 1024 - %2691 = torch.prim.ListConstruct %int4_3045, %306, %int1024_3046 : (!torch.int, !torch.int, 
!torch.int) -> !torch.list - %2692 = torch.aten.view %2690, %2691 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %2692, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_3047 = torch.constant.int 4 - %int32_3048 = torch.constant.int 32 - %int128_3049 = torch.constant.int 128 - %2693 = torch.prim.ListConstruct %int4_3047, %306, %int32_3048, %int128_3049 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2694 = torch.aten.view %2678, %2693 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2694, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_3050 = torch.constant.int 4 - %int8_3051 = torch.constant.int 8 - %int128_3052 = torch.constant.int 128 - %2695 = torch.prim.ListConstruct %int4_3050, %306, %int8_3051, %int128_3052 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2696 = torch.aten.view %2685, %2695 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2696, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_3053 = torch.constant.int 4 - %int8_3054 = torch.constant.int 8 - %int128_3055 = torch.constant.int 128 - %2697 = torch.prim.ListConstruct %int4_3053, %306, %int8_3054, %int128_3055 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2698 = torch.aten.view %2692, %2697 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2698, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_3056 = torch.constant.int 131072 - %none_3057 = torch.constant.none - %none_3058 = torch.constant.none - %cpu_3059 = torch.constant.device "cpu" - %false_3060 = torch.constant.bool false - %2699 = torch.aten.arange %int131072_3056, %none_3057, %none_3058, %cpu_3059, %false_3060 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_3061 = torch.constant.int 0 - %int128_3062 = torch.constant.int 128 - %none_3063 = torch.constant.none - %none_3064 = torch.constant.none - %cpu_3065 = torch.constant.device "cpu" - %false_3066 = torch.constant.bool false - %2700 = torch.aten.arange.start %int0_3061, %int128_3062, %none_3063, %none_3064, %cpu_3065, %false_3066 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_3067 = torch.constant.int 2 - %2701 = torch.aten.floor_divide.Scalar %2700, %int2_3067 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_3068 = torch.constant.int 6 - %2702 = torch.prims.convert_element_type %2701, %int6_3068 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> + %2638 = torch.aten.cat %2637, %int-1_3001 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2638, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %2639 = torch.aten.mul.Tensor %2638, %2632 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2639, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_3002 = torch.constant.int 1 + %2640 = torch.aten.add.Tensor %2633, %2639, %int1_3002 : !torch.vtensor<[4,?,8,128],f16>, 
!torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2640, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_3003 = torch.constant.int 32 + %2641 = torch.aten.mul.Scalar %arg2, %int32_3003 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2641, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int8_3004 = torch.constant.int 8 + %int1_3005 = torch.constant.int 1 + %2642 = torch.aten.add.Scalar %2641, %int8_3004, %int1_3005 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2642, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_3006 = torch.constant.int 2 + %2643 = torch.aten.mul.Scalar %2642, %int2_3006 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2643, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_3007 = torch.constant.int 0 + %int1_3008 = torch.constant.int 1 + %2644 = torch.aten.add.Scalar %2643, %int0_3007, %int1_3008 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2644, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %2645 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %2646 = torch.aten.view %2644, %2645 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %2646, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_3009 = torch.constant.int 4 + %int32_3010 = torch.constant.int 32 + %int8_3011 = torch.constant.int 8 + %int128_3012 = torch.constant.int 128 + %2647 = torch.prim.ListConstruct %int4_3009, %296, %int32_3010, %int8_3011, %int128_3012 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2648 = torch.aten.view %2640, %2647 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2648, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_3013 = torch.constant.int 32 + %int8_3014 = torch.constant.int 8 + %int128_3015 = torch.constant.int 128 + %2649 = torch.prim.ListConstruct %504, %int32_3013, %int8_3014, %int128_3015 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2650 = torch.aten.view %2648, %2649 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %2650, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_3016 = torch.constant.int 1 + %int2_3017 = torch.constant.int 2 + %2651 = torch.aten.transpose.int %2650, %int1_3016, %int2_3017 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2651, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_3018 = torch.constant.int 5 + %2652 = torch.prims.convert_element_type %2651, %int5_3018 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2652, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_3019 = torch.constant.int 32 + %int2_3020 = torch.constant.int 2 + %int8_3021 = torch.constant.int 8 + %int32_3022 = torch.constant.int 32 + %int128_3023 = torch.constant.int 128 + 
%2653 = torch.prim.ListConstruct %297, %int32_3019, %int2_3020, %int8_3021, %int32_3022, %int128_3023 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %2654 = torch.aten.view %2416, %2653 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2654, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_3024 = torch.constant.int 8 + %int32_3025 = torch.constant.int 32 + %int128_3026 = torch.constant.int 128 + %2655 = torch.prim.ListConstruct %497, %int8_3024, %int32_3025, %int128_3026 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %2656 = torch.aten.view %2654, %2655 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2656, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %2657 = torch.prim.ListConstruct %2646 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>> + %false_3027 = torch.constant.bool false + %2658 = torch.aten.index_put %2656, %2657, %2652, %false_3027 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2658, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_3028 = torch.constant.int 32 + %int2_3029 = torch.constant.int 2 + %int8_3030 = torch.constant.int 8 + %int32_3031 = torch.constant.int 32 + %int128_3032 = torch.constant.int 128 + %2659 = torch.prim.ListConstruct %297, %int32_3028, %int2_3029, %int8_3030, %int32_3031, %int128_3032 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %2660 = torch.aten.view %2658, %2659 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2660, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_3033 = torch.constant.int 2097152 + %2661 = torch.prim.ListConstruct %297, %int2097152_3033 : (!torch.int, !torch.int) -> !torch.list<int> + %2662 = torch.aten.view %2660, %2661 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2662, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_3034 = torch.constant.int 32 + %int2_3035 = torch.constant.int 2 + %int8_3036 = torch.constant.int 8 + %int32_3037 = torch.constant.int 32 + %int128_3038 = torch.constant.int 128 + %2663 = torch.prim.ListConstruct %297, %int32_3034, %int2_3035, %int8_3036, %int32_3037, %int128_3038 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %2664 = torch.aten.view %2662, %2663 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2664, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_3039 = torch.constant.int 8 + %int32_3040 = torch.constant.int 32 + %int128_3041 = torch.constant.int 128 + %2665 = torch.prim.ListConstruct %497, %int8_3039, %int32_3040, %int128_3041 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %2666 = torch.aten.view %2664, %2665 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2666, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> :
!torch.vtensor<[?,8,32,128],f16> + %int32_3042 = torch.constant.int 32 + %2667 = torch.aten.mul.Scalar %arg2, %int32_3042 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2667, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int8_3043 = torch.constant.int 8 + %int1_3044 = torch.constant.int 1 + %2668 = torch.aten.add.Scalar %2667, %int8_3043, %int1_3044 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2668, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_3045 = torch.constant.int 2 + %2669 = torch.aten.mul.Scalar %2668, %int2_3045 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2669, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_3046 = torch.constant.int 1 + %int1_3047 = torch.constant.int 1 + %2670 = torch.aten.add.Scalar %2669, %int1_3046, %int1_3047 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2670, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %2671 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list<int> + %2672 = torch.aten.view %2670, %2671 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %2672, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_3048 = torch.constant.int 4 + %int32_3049 = torch.constant.int 32 + %int8_3050 = torch.constant.int 8 + %int128_3051 = torch.constant.int 128 + %2673 = torch.prim.ListConstruct %int4_3048, %296, %int32_3049, %int8_3050, %int128_3051 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %2674 = torch.aten.view %2514, %2673 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2674, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_3052 = torch.constant.int 32 + %int8_3053 = torch.constant.int 8 + %int128_3054 = torch.constant.int 128 + %2675 = torch.prim.ListConstruct %504, %int32_3052, %int8_3053, %int128_3054 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %2676 = torch.aten.view %2674, %2675 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %2676, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_3055 = torch.constant.int 1 + %int2_3056 = torch.constant.int 2 + %2677 = torch.aten.transpose.int %2676, %int1_3055, %int2_3056 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2677, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_3057 = torch.constant.int 5 + %2678 = torch.prims.convert_element_type %2677, %int5_3057 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2678, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %2679 = torch.prim.ListConstruct %2672 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>> + %false_3058 = torch.constant.bool false + %2680 = torch.aten.index_put %2666, %2679, %2678, %false_3058 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2680,
[%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_3059 = torch.constant.int 32 + %int2_3060 = torch.constant.int 2 + %int8_3061 = torch.constant.int 8 + %int32_3062 = torch.constant.int 32 + %int128_3063 = torch.constant.int 128 + %2681 = torch.prim.ListConstruct %297, %int32_3059, %int2_3060, %int8_3061, %int32_3062, %int128_3063 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2682 = torch.aten.view %2680, %2681 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2682, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_3064 = torch.constant.int 2097152 + %2683 = torch.prim.ListConstruct %297, %int2097152_3064 : (!torch.int, !torch.int) -> !torch.list + %2684 = torch.aten.view %2682, %2683 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2684, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_3065 = torch.constant.int -2 + %2685 = torch.aten.unsqueeze %2640, %int-2_3065 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2685, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_3066 = torch.constant.int 4 + %int8_3067 = torch.constant.int 8 + %int4_3068 = torch.constant.int 4 %int128_3069 = torch.constant.int 128 - %2703 = torch.aten.div.Scalar %2702, %int128_3069 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_3070 = torch.constant.float 2.000000e+00 - %2704 = torch.aten.mul.Scalar %2703, %float2.000000e00_3070 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_3071 = torch.constant.float 5.000000e+05 - %2705 = torch.aten.pow.Scalar %float5.000000e05_3071, %2704 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %2706 = torch.aten.reciprocal %2705 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_3072 = torch.constant.float 1.000000e+00 - %2707 = torch.aten.mul.Scalar %2706, %float1.000000e00_3072 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_3073 = torch.constant.int 1 - %2708 = torch.aten.unsqueeze %2699, %int1_3073 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_3074 = torch.constant.int 0 - %2709 = torch.aten.unsqueeze %2707, %int0_3074 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %2710 = torch.aten.mul.Tensor %2708, %2709 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_3075 = torch.constant.int 1 - %2711 = torch.aten.size.int %2678, %int1_3075 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_3076 = torch.constant.int 0 - %2712 = torch.aten.add.int %int0_3076, %2711 : !torch.int, !torch.int -> !torch.int - %int0_3077 = torch.constant.int 0 - %int0_3078 = torch.constant.int 0 - %int1_3079 = torch.constant.int 1 - %2713 = torch.aten.slice.Tensor %2710, %int0_3077, %int0_3078, %2712, %int1_3079 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2713, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3080 = torch.constant.int 1 + %2686 = 
torch.prim.ListConstruct %int4_3066, %298, %int8_3067, %int4_3068, %int128_3069 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_3070 = torch.constant.bool false + %2687 = torch.aten.expand %2685, %2686, %false_3070 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2687, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_3071 = torch.constant.int 0 + %2688 = torch.aten.clone %2687, %int0_3071 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2688, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_3072 = torch.constant.int 4 + %int32_3073 = torch.constant.int 32 + %int128_3074 = torch.constant.int 128 + %2689 = torch.prim.ListConstruct %int4_3072, %298, %int32_3073, %int128_3074 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2690 = torch.aten._unsafe_view %2688, %2689 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2690, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_3075 = torch.constant.int -2 + %2691 = torch.aten.unsqueeze %2514, %int-2_3075 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2691, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_3076 = torch.constant.int 4 + %int8_3077 = torch.constant.int 8 + %int4_3078 = torch.constant.int 4 + %int128_3079 = torch.constant.int 128 + %2692 = torch.prim.ListConstruct %int4_3076, %298, %int8_3077, %int4_3078, %int128_3079 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_3080 = torch.constant.bool false + %2693 = torch.aten.expand %2691, %2692, %false_3080 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2693, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int0_3081 = torch.constant.int 0 - %int9223372036854775807_3082 = torch.constant.int 9223372036854775807 - %int1_3083 = torch.constant.int 1 - %2714 = torch.aten.slice.Tensor %2713, %int1_3080, %int0_3081, %int9223372036854775807_3082, %int1_3083 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2714, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3084 = torch.constant.int 1 - %int0_3085 = torch.constant.int 0 - %int9223372036854775807_3086 = torch.constant.int 9223372036854775807 + %2694 = torch.aten.clone %2693, %int0_3081 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2694, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_3082 = torch.constant.int 4 + %int32_3083 = torch.constant.int 32 + %int128_3084 = torch.constant.int 128 + %2695 = torch.prim.ListConstruct %int4_3082, %298, %int32_3083, %int128_3084 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2696 = torch.aten._unsafe_view %2694, %2695 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2696, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 
128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_3085 = torch.constant.int 1 + %int2_3086 = torch.constant.int 2 + %2697 = torch.aten.transpose.int %2577, %int1_3085, %int2_3086 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2697, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_3087 = torch.constant.int 1 - %2715 = torch.aten.slice.Tensor %2714, %int1_3084, %int0_3085, %int9223372036854775807_3086, %int1_3087 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2715, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_3088 = torch.constant.int 0 - %2716 = torch.aten.unsqueeze %2715, %int0_3088 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2716, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_3089 = torch.constant.int 1 - %int0_3090 = torch.constant.int 0 - %int9223372036854775807_3091 = torch.constant.int 9223372036854775807 - %int1_3092 = torch.constant.int 1 - %2717 = torch.aten.slice.Tensor %2716, %int1_3089, %int0_3090, %int9223372036854775807_3091, %int1_3092 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2717, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_3093 = torch.constant.int 2 - %int0_3094 = torch.constant.int 0 - %int9223372036854775807_3095 = torch.constant.int 9223372036854775807 - %int1_3096 = torch.constant.int 1 - %2718 = torch.aten.slice.Tensor %2717, %int2_3093, %int0_3094, %int9223372036854775807_3095, %int1_3096 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2718, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_3097 = torch.constant.int 4 - %int1_3098 = torch.constant.int 1 - %int1_3099 = torch.constant.int 1 - %2719 = torch.prim.ListConstruct %int4_3097, %int1_3098, %int1_3099 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2720 = torch.aten.repeat %2718, %2719 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %2720, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_3100 = torch.constant.int 6 - %2721 = torch.prims.convert_element_type %2694, %int6_3100 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %2721, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %2722 = torch_c.to_builtin_tensor %2721 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %2723 = torch_c.to_builtin_tensor %2720 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %2724 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%2722, %2723) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %2725 = torch_c.from_builtin_tensor %2724 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %2725, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_3101 = torch.constant.int 5 - %2726 = torch.prims.convert_element_type %2725, %int5_3101 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> 
!torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2726, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_3102 = torch.constant.int 131072 - %none_3103 = torch.constant.none - %none_3104 = torch.constant.none - %cpu_3105 = torch.constant.device "cpu" - %false_3106 = torch.constant.bool false - %2727 = torch.aten.arange %int131072_3102, %none_3103, %none_3104, %cpu_3105, %false_3106 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_3107 = torch.constant.int 0 - %int128_3108 = torch.constant.int 128 - %none_3109 = torch.constant.none - %none_3110 = torch.constant.none - %cpu_3111 = torch.constant.device "cpu" - %false_3112 = torch.constant.bool false - %2728 = torch.aten.arange.start %int0_3107, %int128_3108, %none_3109, %none_3110, %cpu_3111, %false_3112 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_3113 = torch.constant.int 2 - %2729 = torch.aten.floor_divide.Scalar %2728, %int2_3113 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_3114 = torch.constant.int 6 - %2730 = torch.prims.convert_element_type %2729, %int6_3114 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_3115 = torch.constant.int 128 - %2731 = torch.aten.div.Scalar %2730, %int128_3115 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_3116 = torch.constant.float 2.000000e+00 - %2732 = torch.aten.mul.Scalar %2731, %float2.000000e00_3116 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_3117 = torch.constant.float 5.000000e+05 - %2733 = torch.aten.pow.Scalar %float5.000000e05_3117, %2732 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %2734 = torch.aten.reciprocal %2733 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_3118 = torch.constant.float 1.000000e+00 - %2735 = torch.aten.mul.Scalar %2734, %float1.000000e00_3118 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_3119 = torch.constant.int 1 - %2736 = torch.aten.unsqueeze %2727, %int1_3119 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_3120 = torch.constant.int 0 - %2737 = torch.aten.unsqueeze %2735, %int0_3120 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %2738 = torch.aten.mul.Tensor %2736, %2737 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_3121 = torch.constant.int 1 - %2739 = torch.aten.size.int %2685, %int1_3121 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_3122 = torch.constant.int 0 - %2740 = torch.aten.add.int %int0_3122, %2739 : !torch.int, !torch.int -> !torch.int - %int0_3123 = torch.constant.int 0 - %int0_3124 = torch.constant.int 0 - %int1_3125 = torch.constant.int 1 - %2741 = torch.aten.slice.Tensor %2738, %int0_3123, %int0_3124, %2740, %int1_3125 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2741, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3126 = torch.constant.int 1 - %int0_3127 = torch.constant.int 0 - %int9223372036854775807_3128 = torch.constant.int 9223372036854775807 - %int1_3129 = torch.constant.int 1 - %2742 = torch.aten.slice.Tensor %2741, %int1_3126, %int0_3127, 
%int9223372036854775807_3128, %int1_3129 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2742, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3130 = torch.constant.int 1 - %int0_3131 = torch.constant.int 0 - %int9223372036854775807_3132 = torch.constant.int 9223372036854775807 - %int1_3133 = torch.constant.int 1 - %2743 = torch.aten.slice.Tensor %2742, %int1_3130, %int0_3131, %int9223372036854775807_3132, %int1_3133 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2743, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_3134 = torch.constant.int 0 - %2744 = torch.aten.unsqueeze %2743, %int0_3134 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2744, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_3135 = torch.constant.int 1 - %int0_3136 = torch.constant.int 0 - %int9223372036854775807_3137 = torch.constant.int 9223372036854775807 - %int1_3138 = torch.constant.int 1 - %2745 = torch.aten.slice.Tensor %2744, %int1_3135, %int0_3136, %int9223372036854775807_3137, %int1_3138 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2745, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_3139 = torch.constant.int 2 - %int0_3140 = torch.constant.int 0 - %int9223372036854775807_3141 = torch.constant.int 9223372036854775807 - %int1_3142 = torch.constant.int 1 - %2746 = torch.aten.slice.Tensor %2745, %int2_3139, %int0_3140, %int9223372036854775807_3141, %int1_3142 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2746, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_3143 = torch.constant.int 4 - %int1_3144 = torch.constant.int 1 - %int1_3145 = torch.constant.int 1 - %2747 = torch.prim.ListConstruct %int4_3143, %int1_3144, %int1_3145 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2748 = torch.aten.repeat %2746, %2747 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %2748, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_3146 = torch.constant.int 6 - %2749 = torch.prims.convert_element_type %2696, %int6_3146 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %2749, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %2750 = torch_c.to_builtin_tensor %2749 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %2751 = torch_c.to_builtin_tensor %2748 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %2752 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%2750, %2751) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %2753 = torch_c.from_builtin_tensor %2752 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %2753, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_3147 = torch.constant.int 5 - %2754 = torch.prims.convert_element_type %2753, %int5_3147 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> 
!torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2754, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_3148 = torch.constant.int 64 - %2755 = torch.aten.mul.Scalar %arg2, %int64_3148 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2755, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int24 = torch.constant.int 24 - %int1_3149 = torch.constant.int 1 - %2756 = torch.aten.add.Scalar %2755, %int24, %int1_3149 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2756, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_3150 = torch.constant.int 4 - %int32_3151 = torch.constant.int 32 - %int8_3152 = torch.constant.int 8 - %int128_3153 = torch.constant.int 128 - %2757 = torch.prim.ListConstruct %int4_3150, %398, %int32_3151, %int8_3152, %int128_3153 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2758 = torch.aten.view %2754, %2757 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2758, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_3154 = torch.constant.int 4 - %2759 = torch.aten.mul.int %int4_3154, %398 : !torch.int, !torch.int -> !torch.int - %int32_3155 = torch.constant.int 32 - %int8_3156 = torch.constant.int 8 - %int128_3157 = torch.constant.int 128 - %2760 = torch.prim.ListConstruct %2759, %int32_3155, %int8_3156, %int128_3157 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2761 = torch.aten.view %2758, %2760 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2761, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int2_3088 = torch.constant.int 2 + %2698 = torch.aten.transpose.int %2690, %int1_3087, %int2_3088 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2698, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_3089 = torch.constant.int 1 + %int2_3090 = torch.constant.int 2 + %2699 = torch.aten.transpose.int %2696, %int1_3089, %int2_3090 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2699, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_3091 = torch.constant.float 0.000000e+00 + %false_3092 = torch.constant.bool false + %none_3093 = torch.constant.none + %2700:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2697, %2698, %2699, %float0.000000e00_3091, %false_3092, %327, %none_3093) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %2700#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_3094 = torch.constant.int 1 + %int2_3095 = torch.constant.int 2 + %2701 = torch.aten.transpose.int %2700#0, %int1_3094, %int2_3095 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2701, [%294], affine_map<()[s0] -> 
(4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_3096 = torch.constant.int 4 + %int4096_3097 = torch.constant.int 4096 + %2702 = torch.prim.ListConstruct %int4_3096, %298, %int4096_3097 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2703 = torch.aten.view %2701, %2702 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2703, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_3098 = torch.constant.int -2 + %int-1_3099 = torch.constant.int -1 + %2704 = torch.aten.transpose.int %78, %int-2_3098, %int-1_3099 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_3100 = torch.constant.int 5 + %2705 = torch.prims.convert_element_type %2704, %int5_3100 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_3101 = torch.constant.int 4096 + %2706 = torch.prim.ListConstruct %342, %int4096_3101 : (!torch.int, !torch.int) -> !torch.list + %2707 = torch.aten.view %2703, %2706 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2707, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2708 = torch.aten.mm %2707, %2705 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2708, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_3102 = torch.constant.int 4 + %int4096_3103 = torch.constant.int 4096 + %2709 = torch.prim.ListConstruct %int4_3102, %298, %int4096_3103 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2710 = torch.aten.view %2708, %2709 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2710, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_3104 = torch.constant.int 1 + %2711 = torch.aten.add.Tensor %2477, %2710, %int1_3104 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2711, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_3105 = torch.constant.int 6 + %2712 = torch.prims.convert_element_type %2711, %int6_3105 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2712, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_3106 = torch.constant.int 2 + %2713 = torch.aten.pow.Tensor_Scalar %2712, %int2_3106 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2713, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_3107 = torch.constant.int -1 + %2714 = torch.prim.ListConstruct %int-1_3107 : (!torch.int) -> !torch.list + %true_3108 = torch.constant.bool true + %none_3109 = torch.constant.none + %2715 = torch.aten.mean.dim %2713, %2714, %true_3108, %none_3109 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2715, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_3110 = torch.constant.float 9.9999997473787516E-6 + %int1_3111 = torch.constant.int 1 + %2716 = torch.aten.add.Scalar %2715, %float9.999990e-06_3110, %int1_3111 : !torch.vtensor<[4,?,1],f32>, 
!torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2716, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2717 = torch.aten.rsqrt %2716 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2717, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2718 = torch.aten.mul.Tensor %2712, %2717 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2718, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_3112 = torch.constant.int 5 + %2719 = torch.prims.convert_element_type %2718, %int5_3112 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2719, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %2720 = torch.aten.mul.Tensor %79, %2719 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2720, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_3113 = torch.constant.int 5 + %2721 = torch.prims.convert_element_type %2720, %int5_3113 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2721, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_3114 = torch.constant.int -2 + %int-1_3115 = torch.constant.int -1 + %2722 = torch.aten.transpose.int %80, %int-2_3114, %int-1_3115 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_3116 = torch.constant.int 5 + %2723 = torch.prims.convert_element_type %2722, %int5_3116 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_3117 = torch.constant.int 4096 + %2724 = torch.prim.ListConstruct %342, %int4096_3117 : (!torch.int, !torch.int) -> !torch.list + %2725 = torch.aten.view %2721, %2724 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2725, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2726 = torch.aten.mm %2725, %2723 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %2726, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_3118 = torch.constant.int 4 + %int14336_3119 = torch.constant.int 14336 + %2727 = torch.prim.ListConstruct %int4_3118, %298, %int14336_3119 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2728 = torch.aten.view %2726, %2727 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2728, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %2729 = torch.aten.silu %2728 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2729, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_3120 = torch.constant.int -2 + %int-1_3121 = torch.constant.int -1 + %2730 = torch.aten.transpose.int %81, %int-2_3120, %int-1_3121 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_3122 = torch.constant.int 5 + %2731 = torch.prims.convert_element_type %2730, %int5_3122 : !torch.vtensor<[4096,14336],f16>, !torch.int -> 
!torch.vtensor<[4096,14336],f16> + %int4096_3123 = torch.constant.int 4096 + %2732 = torch.prim.ListConstruct %342, %int4096_3123 : (!torch.int, !torch.int) -> !torch.list + %2733 = torch.aten.view %2721, %2732 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2733, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2734 = torch.aten.mm %2733, %2731 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %2734, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_3124 = torch.constant.int 4 + %int14336_3125 = torch.constant.int 14336 + %2735 = torch.prim.ListConstruct %int4_3124, %298, %int14336_3125 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2736 = torch.aten.view %2734, %2735 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2736, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %2737 = torch.aten.mul.Tensor %2729, %2736 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2737, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_3126 = torch.constant.int -2 + %int-1_3127 = torch.constant.int -1 + %2738 = torch.aten.transpose.int %82, %int-2_3126, %int-1_3127 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_3128 = torch.constant.int 5 + %2739 = torch.prims.convert_element_type %2738, %int5_3128 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_3129 = torch.constant.int 14336 + %2740 = torch.prim.ListConstruct %342, %int14336_3129 : (!torch.int, !torch.int) -> !torch.list + %2741 = torch.aten.view %2737, %2740 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %2741, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %2742 = torch.aten.mm %2741, %2739 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2742, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_3130 = torch.constant.int 4 + %int4096_3131 = torch.constant.int 4096 + %2743 = torch.prim.ListConstruct %int4_3130, %298, %int4096_3131 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2744 = torch.aten.view %2742, %2743 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2744, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_3132 = torch.constant.int 1 + %2745 = torch.aten.add.Tensor %2711, %2744, %int1_3132 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2745, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_3133 = torch.constant.int 6 + %2746 = torch.prims.convert_element_type %2745, %int6_3133 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2746, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_3134 = torch.constant.int 2 + %2747 = torch.aten.pow.Tensor_Scalar %2746, %int2_3134 : 
!torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2747, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_3135 = torch.constant.int -1 + %2748 = torch.prim.ListConstruct %int-1_3135 : (!torch.int) -> !torch.list + %true_3136 = torch.constant.bool true + %none_3137 = torch.constant.none + %2749 = torch.aten.mean.dim %2747, %2748, %true_3136, %none_3137 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2749, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_3138 = torch.constant.float 9.9999997473787516E-6 + %int1_3139 = torch.constant.int 1 + %2750 = torch.aten.add.Scalar %2749, %float9.999990e-06_3138, %int1_3139 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2750, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2751 = torch.aten.rsqrt %2750 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2751, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2752 = torch.aten.mul.Tensor %2746, %2751 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2752, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_3140 = torch.constant.int 5 + %2753 = torch.prims.convert_element_type %2752, %int5_3140 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2753, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %2754 = torch.aten.mul.Tensor %83, %2753 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2754, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_3141 = torch.constant.int 5 + %2755 = torch.prims.convert_element_type %2754, %int5_3141 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2755, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_3142 = torch.constant.int -2 + %int-1_3143 = torch.constant.int -1 + %2756 = torch.aten.transpose.int %84, %int-2_3142, %int-1_3143 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_3144 = torch.constant.int 5 + %2757 = torch.prims.convert_element_type %2756, %int5_3144 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_3145 = torch.constant.int 4096 + %2758 = torch.prim.ListConstruct %342, %int4096_3145 : (!torch.int, !torch.int) -> !torch.list + %2759 = torch.aten.view %2755, %2758 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2759, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2760 = torch.aten.mm %2759, %2757 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2760, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_3146 = torch.constant.int 4 + %int4096_3147 = torch.constant.int 4096 + %2761 = torch.prim.ListConstruct %int4_3146, %298, %int4096_3147 : (!torch.int, !torch.int, 
!torch.int) -> !torch.list + %2762 = torch.aten.view %2760, %2761 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2762, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_3148 = torch.constant.int -2 + %int-1_3149 = torch.constant.int -1 + %2763 = torch.aten.transpose.int %85, %int-2_3148, %int-1_3149 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_3150 = torch.constant.int 5 + %2764 = torch.prims.convert_element_type %2763, %int5_3150 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_3151 = torch.constant.int 4096 + %2765 = torch.prim.ListConstruct %342, %int4096_3151 : (!torch.int, !torch.int) -> !torch.list + %2766 = torch.aten.view %2755, %2765 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2766, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2767 = torch.aten.mm %2766, %2764 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %2767, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_3152 = torch.constant.int 4 + %int1024_3153 = torch.constant.int 1024 + %2768 = torch.prim.ListConstruct %int4_3152, %298, %int1024_3153 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2769 = torch.aten.view %2767, %2768 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %2769, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_3154 = torch.constant.int -2 + %int-1_3155 = torch.constant.int -1 + %2770 = torch.aten.transpose.int %86, %int-2_3154, %int-1_3155 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_3156 = torch.constant.int 5 + %2771 = torch.prims.convert_element_type %2770, %int5_3156 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_3157 = torch.constant.int 4096 + %2772 = torch.prim.ListConstruct %342, %int4096_3157 : (!torch.int, !torch.int) -> !torch.list + %2773 = torch.aten.view %2755, %2772 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2773, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2774 = torch.aten.mm %2773, %2771 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %2774, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> %int4_3158 = torch.constant.int 4 - %2762 = torch.aten.mul.int %int4_3158, %398 : !torch.int, !torch.int -> !torch.int - %2763 = torch.prim.ListConstruct %2762 : (!torch.int) -> !torch.list - %2764 = torch.aten.view %2756, %2763 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2764, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_3159 = torch.constant.int 32 - %int2_3160 = torch.constant.int 2 + %int1024_3159 = torch.constant.int 1024 + %2775 = torch.prim.ListConstruct %int4_3158, %298, %int1024_3159 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2776 = torch.aten.view %2774, %2775 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %2776, [%294], 
affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_3160 = torch.constant.int 4 %int32_3161 = torch.constant.int 32 - %int8_3162 = torch.constant.int 8 - %int128_3163 = torch.constant.int 128 - %2765 = torch.prim.ListConstruct %389, %int32_3159, %int2_3160, %int32_3161, %int8_3162, %int128_3163 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2766 = torch.aten.view %2598, %2765 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2766, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_3164 = torch.constant.int 32 - %2767 = torch.aten.mul.int %389, %int32_3164 : !torch.int, !torch.int -> !torch.int - %int2_3165 = torch.constant.int 2 - %2768 = torch.aten.mul.int %2767, %int2_3165 : !torch.int, !torch.int -> !torch.int - %int32_3166 = torch.constant.int 32 + %int128_3162 = torch.constant.int 128 + %2777 = torch.prim.ListConstruct %int4_3160, %298, %int32_3161, %int128_3162 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2778 = torch.aten.view %2762, %2777 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2778, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_3163 = torch.constant.int 4 + %int8_3164 = torch.constant.int 8 + %int128_3165 = torch.constant.int 128 + %2779 = torch.prim.ListConstruct %int4_3163, %298, %int8_3164, %int128_3165 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2780 = torch.aten.view %2769, %2779 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2780, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_3166 = torch.constant.int 4 %int8_3167 = torch.constant.int 8 %int128_3168 = torch.constant.int 128 - %2769 = torch.prim.ListConstruct %2768, %int32_3166, %int8_3167, %int128_3168 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2770 = torch.aten.view %2766, %2769 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2770, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %2771 = torch.prim.ListConstruct %2764 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_3169 = torch.constant.bool false - %2772 = torch.aten.index_put %2770, %2771, %2761, %false_3169 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2772, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_3170 = torch.constant.int 32 - %int2_3171 = torch.constant.int 2 - %int32_3172 = torch.constant.int 32 - %int8_3173 = torch.constant.int 8 - %int128_3174 = torch.constant.int 128 - %2773 = torch.prim.ListConstruct %389, %int32_3170, %int2_3171, %int32_3172, %int8_3173, %int128_3174 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2774 = torch.aten.view %2772, %2773 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2774, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3175 = torch.constant.int 2097152 - %2775 = torch.prim.ListConstruct 
%389, %int2097152_3175 : (!torch.int, !torch.int) -> !torch.list - %2776 = torch.aten.view %2774, %2775 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2776, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_3176 = torch.constant.int 32 - %int2_3177 = torch.constant.int 2 - %int32_3178 = torch.constant.int 32 - %int8_3179 = torch.constant.int 8 - %int128_3180 = torch.constant.int 128 - %2777 = torch.prim.ListConstruct %389, %int32_3176, %int2_3177, %int32_3178, %int8_3179, %int128_3180 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2778 = torch.aten.view %2776, %2777 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2778, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_3181 = torch.constant.int 32 - %int8_3182 = torch.constant.int 8 - %int128_3183 = torch.constant.int 128 - %2779 = torch.prim.ListConstruct %2768, %int32_3181, %int8_3182, %int128_3183 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2780 = torch.aten.view %2778, %2779 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2780, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_3184 = torch.constant.int 4 - %int32_3185 = torch.constant.int 32 - %int8_3186 = torch.constant.int 8 - %int128_3187 = torch.constant.int 128 - %2781 = torch.prim.ListConstruct %int4_3184, %398, %int32_3185, %int8_3186, %int128_3187 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2782 = torch.aten.view %2698, %2781 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2782, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_3188 = torch.constant.int 4 - %2783 = torch.aten.mul.int %int4_3188, %398 : !torch.int, !torch.int -> !torch.int - %int32_3189 = torch.constant.int 32 - %int8_3190 = torch.constant.int 8 - %int128_3191 = torch.constant.int 128 - %2784 = torch.prim.ListConstruct %2783, %int32_3189, %int8_3190, %int128_3191 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2785 = torch.aten.view %2782, %2784 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2785, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %2781 = torch.prim.ListConstruct %int4_3166, %298, %int8_3167, %int128_3168 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2782 = torch.aten.view %2776, %2781 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2782, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_3169 = torch.constant.int 131072 + %none_3170 = torch.constant.none + %none_3171 = torch.constant.none + %cpu_3172 = torch.constant.device "cpu" + %false_3173 = torch.constant.bool false + %2783 = torch.aten.arange %int131072_3169, %none_3170, %none_3171, %cpu_3172, %false_3173 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_3174 = torch.constant.int 0 + %int128_3175 = torch.constant.int 128 + %int2_3176 = torch.constant.int 2 + %int4_3177 = 
torch.constant.int 4 + %none_3178 = torch.constant.none + %cpu_3179 = torch.constant.device "cpu" + %false_3180 = torch.constant.bool false + %2784 = torch.aten.arange.start_step %int0_3174, %int128_3175, %int2_3176, %int4_3177, %none_3178, %cpu_3179, %false_3180 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_3181 = torch.constant.int 6 + %2785 = torch.prims.convert_element_type %2784, %int6_3181 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_3182 = torch.constant.int 128 + %2786 = torch.aten.div.Scalar %2785, %int128_3182 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_3183 = torch.constant.float 5.000000e+05 + %2787 = torch.aten.pow.Scalar %float5.000000e05_3183, %2786 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2788 = torch.aten.reciprocal %2787 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_3184 = torch.constant.float 1.000000e+00 + %2789 = torch.aten.mul.Scalar %2788, %float1.000000e00_3184 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %2790 = torch.aten.reciprocal %2789 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_3185 = torch.constant.float 6.2831853071795862 + %2791 = torch.aten.mul.Scalar %2790, %float6.283190e00_3185 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_3186 = torch.constant.float 8.192000e+03 + %2792 = torch.aten.gt.Scalar %2791, %float8.192000e03_3186 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_3187 = torch.constant.int 8 + %2793 = torch.aten.div.Scalar %2789, %int8_3187 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2794 = torch.aten.where.self %2792, %2793, %2789 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2795 = torch.aten.reciprocal %2791 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_3188 = torch.constant.int 8192 + %2796 = torch.aten.mul.Scalar %2795, %int8192_3188 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_3189 = torch.constant.int 1 + %int1_3190 = torch.constant.int 1 + %2797 = torch.aten.sub.Scalar %2796, %int1_3189, %int1_3190 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_3191 = torch.constant.int 3 + %2798 = torch.aten.div.Scalar %2797, %int3_3191 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_3192 = torch.constant.int 1 %int1_3193 = torch.constant.int 1 - %2786 = torch.aten.add.Scalar %2756, %int1_3192, %int1_3193 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2786, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_3194 = torch.constant.int 4 - %2787 = torch.aten.mul.int %int4_3194, %398 : !torch.int, !torch.int -> !torch.int - %2788 = torch.prim.ListConstruct %2787 : (!torch.int) -> !torch.list - %2789 = torch.aten.view %2786, %2788 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2789, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %2790 = torch.prim.ListConstruct %2789 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_3195 = torch.constant.bool false - %2791 = torch.aten.index_put %2780, %2790, %2785, %false_3195 : !torch.vtensor<[?,32,8,128],f16>, 
!torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2791, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_3196 = torch.constant.int 32 - %int2_3197 = torch.constant.int 2 - %int32_3198 = torch.constant.int 32 - %int8_3199 = torch.constant.int 8 - %int128_3200 = torch.constant.int 128 - %2792 = torch.prim.ListConstruct %389, %int32_3196, %int2_3197, %int32_3198, %int8_3199, %int128_3200 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2793 = torch.aten.view %2791, %2792 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2793, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3201 = torch.constant.int 2097152 - %2794 = torch.prim.ListConstruct %389, %int2097152_3201 : (!torch.int, !torch.int) -> !torch.list - %2795 = torch.aten.view %2793, %2794 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2795, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_3202 = torch.constant.int -2 - %2796 = torch.aten.unsqueeze %2754, %int-2_3202 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2796, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_3203 = torch.constant.int 4 - %int8_3204 = torch.constant.int 8 - %int4_3205 = torch.constant.int 4 - %int128_3206 = torch.constant.int 128 - %2797 = torch.prim.ListConstruct %int4_3203, %2739, %int8_3204, %int4_3205, %int128_3206 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_3207 = torch.constant.bool false - %2798 = torch.aten.expand %2796, %2797, %false_3207 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2798, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_3208 = torch.constant.int 0 - %2799 = torch.aten.clone %2798, %int0_3208 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2799, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_3209 = torch.constant.int 4 - %int32_3210 = torch.constant.int 32 - %int128_3211 = torch.constant.int 128 - %2800 = torch.prim.ListConstruct %int4_3209, %2739, %int32_3210, %int128_3211 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2801 = torch.aten._unsafe_view %2799, %2800 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2801, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_3212 = torch.constant.int -2 - %2802 = torch.aten.unsqueeze %2698, %int-2_3212 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2802, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_3213 = torch.constant.int 1 - %2803 = torch.aten.size.int %2692, %int1_3213 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_3214 = torch.constant.int 4 - %int8_3215 = torch.constant.int 8 - %int4_3216 = torch.constant.int 4 - %int128_3217 = 
torch.constant.int 128 - %2804 = torch.prim.ListConstruct %int4_3214, %2803, %int8_3215, %int4_3216, %int128_3217 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_3218 = torch.constant.bool false - %2805 = torch.aten.expand %2802, %2804, %false_3218 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2805, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_3219 = torch.constant.int 0 - %2806 = torch.aten.clone %2805, %int0_3219 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2806, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_3220 = torch.constant.int 4 - %int32_3221 = torch.constant.int 32 - %int128_3222 = torch.constant.int 128 - %2807 = torch.prim.ListConstruct %int4_3220, %2803, %int32_3221, %int128_3222 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2808 = torch.aten._unsafe_view %2806, %2807 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2808, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_3223 = torch.constant.int 1 - %int2_3224 = torch.constant.int 2 - %2809 = torch.aten.transpose.int %2726, %int1_3223, %int2_3224 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2809, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_3225 = torch.constant.int 1 - %int2_3226 = torch.constant.int 2 - %2810 = torch.aten.transpose.int %2801, %int1_3225, %int2_3226 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2810, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_3227 = torch.constant.int 1 - %int2_3228 = torch.constant.int 2 - %2811 = torch.aten.transpose.int %2808, %int1_3227, %int2_3228 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2811, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_3229 = torch.constant.float 0.000000e+00 - %true_3230 = torch.constant.bool true - %none_3231 = torch.constant.none - %none_3232 = torch.constant.none - %2812:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2809, %2810, %2811, %float0.000000e00_3229, %true_3230, %none_3231, %none_3232) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %2812#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %2799 = torch.aten.rsub.Scalar %2798, %int1_3192, %int1_3193 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %2800 = torch.aten.mul.Tensor %2799, %2794 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_3194 = torch.constant.int 8 + %2801 = torch.aten.div.Scalar %2800, %int8_3194 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2802 = torch.aten.mul.Tensor %2798, %2794 : 
!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_3195 = torch.constant.int 1 + %2803 = torch.aten.add.Tensor %2801, %2802, %int1_3195 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_3196 = torch.constant.float 2.048000e+03 + %2804 = torch.aten.lt.Scalar %2791, %float2.048000e03_3196 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2805 = torch.aten.bitwise_not %2804 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_3197 = torch.constant.float 8.192000e+03 + %2806 = torch.aten.gt.Scalar %2791, %float8.192000e03_3197 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2807 = torch.aten.bitwise_not %2806 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2808 = torch.aten.mul.Tensor %2805, %2807 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2809 = torch.aten.where.self %2808, %2803, %2794 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2810 = torch.prim.ListConstruct %2809, %2809 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_3198 = torch.constant.int -1 + %2811 = torch.aten.cat %2810, %int-1_3198 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_3199 = torch.constant.int 6 + %2812 = torch.prims.convert_element_type %2811, %int6_3199 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_3200 = torch.constant.int 1 + %2813 = torch.aten.unsqueeze %2783, %int1_3200 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_3201 = torch.constant.int 6 + %2814 = torch.prims.convert_element_type %2813, %int6_3201 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_3202 = torch.constant.int 0 + %2815 = torch.aten.unsqueeze %2812, %int0_3202 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_3203 = torch.constant.int 6 + %2816 = torch.prims.convert_element_type %2815, %int6_3203 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %2817 = torch.aten.mul.Tensor %2814, %2816 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %2818 = torch.aten.cos %2817 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_3204 = torch.constant.int 5 + %2819 = torch.prims.convert_element_type %2818, %int5_3204 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %2820 = torch.aten.sin %2817 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_3205 = torch.constant.int 5 + %2821 = torch.prims.convert_element_type %2820, %int5_3205 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_3206 = torch.constant.int 0 + %int0_3207 = torch.constant.int 0 + %int1_3208 = torch.constant.int 1 + %2822 = torch.aten.slice.Tensor %2819, %int0_3206, %int0_3207, %298, %int1_3208 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2822, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_3209 = torch.constant.int 1 + %int0_3210 = torch.constant.int 0 + %int9223372036854775807_3211 = torch.constant.int 9223372036854775807 + %int1_3212 = torch.constant.int 1 + %2823 = torch.aten.slice.Tensor %2822, %int1_3209, 
%int0_3210, %int9223372036854775807_3211, %int1_3212 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2823, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_3213 = torch.constant.int 0 + %int0_3214 = torch.constant.int 0 + %int1_3215 = torch.constant.int 1 + %2824 = torch.aten.slice.Tensor %2821, %int0_3213, %int0_3214, %298, %int1_3215 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2824, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_3216 = torch.constant.int 1 + %int0_3217 = torch.constant.int 0 + %int9223372036854775807_3218 = torch.constant.int 9223372036854775807 + %int1_3219 = torch.constant.int 1 + %2825 = torch.aten.slice.Tensor %2824, %int1_3216, %int0_3217, %int9223372036854775807_3218, %int1_3219 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2825, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_3220 = torch.constant.int 0 + %2826 = torch.aten.unsqueeze %2823, %int0_3220 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2826, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_3221 = torch.constant.int 1 + %int0_3222 = torch.constant.int 0 + %int9223372036854775807_3223 = torch.constant.int 9223372036854775807 + %int1_3224 = torch.constant.int 1 + %2827 = torch.aten.slice.Tensor %2826, %int1_3221, %int0_3222, %int9223372036854775807_3223, %int1_3224 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2827, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_3225 = torch.constant.int 2 + %2828 = torch.aten.unsqueeze %2827, %int2_3225 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2828, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_3226 = torch.constant.int 3 + %int0_3227 = torch.constant.int 0 + %int9223372036854775807_3228 = torch.constant.int 9223372036854775807 + %int1_3229 = torch.constant.int 1 + %2829 = torch.aten.slice.Tensor %2828, %int3_3226, %int0_3227, %int9223372036854775807_3228, %int1_3229 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2829, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_3230 = torch.constant.int 4 + %int1_3231 = torch.constant.int 1 + %int1_3232 = torch.constant.int 1 %int1_3233 = torch.constant.int 1 - %int2_3234 = torch.constant.int 2 - %2813 = torch.aten.transpose.int %2812#0, %int1_3233, %int2_3234 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2813, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_3235 = torch.constant.int 4 - %int4096_3236 = torch.constant.int 4096 - %2814 = torch.prim.ListConstruct %int4_3235, %2711, %int4096_3236 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2815 = torch.aten.view %2813, %2814 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> 
!torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2815, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3237 = torch.constant.int -2 - %int-1_3238 = torch.constant.int -1 - %2816 = torch.aten.transpose.int %113, %int-2_3237, %int-1_3238 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3239 = torch.constant.int 4 - %2817 = torch.aten.mul.int %int4_3239, %2711 : !torch.int, !torch.int -> !torch.int - %int4096_3240 = torch.constant.int 4096 - %2818 = torch.prim.ListConstruct %2817, %int4096_3240 : (!torch.int, !torch.int) -> !torch.list - %2819 = torch.aten.view %2815, %2818 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2819, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2820 = torch.aten.mm %2819, %2816 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2820, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_3241 = torch.constant.int 4 - %int4096_3242 = torch.constant.int 4096 - %2821 = torch.prim.ListConstruct %int4_3241, %2711, %int4096_3242 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2822 = torch.aten.view %2820, %2821 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2822, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %2830 = torch.prim.ListConstruct %int4_3230, %int1_3231, %int1_3232, %int1_3233 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2831 = torch.aten.repeat %2829, %2830 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2831, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_3234 = torch.constant.int 0 + %2832 = torch.aten.unsqueeze %2825, %int0_3234 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2832, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_3235 = torch.constant.int 1 + %int0_3236 = torch.constant.int 0 + %int9223372036854775807_3237 = torch.constant.int 9223372036854775807 + %int1_3238 = torch.constant.int 1 + %2833 = torch.aten.slice.Tensor %2832, %int1_3235, %int0_3236, %int9223372036854775807_3237, %int1_3238 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2833, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_3239 = torch.constant.int 2 + %2834 = torch.aten.unsqueeze %2833, %int2_3239 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2834, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_3240 = torch.constant.int 3 + %int0_3241 = torch.constant.int 0 + %int9223372036854775807_3242 = torch.constant.int 9223372036854775807 %int1_3243 = torch.constant.int 1 - %2823 = torch.aten.add.Tensor %2661, %2822, %int1_3243 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2823, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_3244 = torch.constant.int 6 - %2824 = 
torch.prims.convert_element_type %2823, %int6_3244 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2824, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_3245 = torch.constant.int 2 - %2825 = torch.aten.pow.Tensor_Scalar %2824, %int2_3245 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2825, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_3246 = torch.constant.int -1 - %2826 = torch.prim.ListConstruct %int-1_3246 : (!torch.int) -> !torch.list - %true_3247 = torch.constant.bool true - %none_3248 = torch.constant.none - %2827 = torch.aten.mean.dim %2825, %2826, %true_3247, %none_3248 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2827, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_3249 = torch.constant.float 9.9999997473787516E-6 - %int1_3250 = torch.constant.int 1 - %2828 = torch.aten.add.Scalar %2827, %float9.999990e-06_3249, %int1_3250 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2828, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2829 = torch.aten.rsqrt %2828 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2829, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2830 = torch.aten.mul.Tensor %2824, %2829 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2830, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3251 = torch.constant.int 5 - %2831 = torch.prims.convert_element_type %2830, %int5_3251 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2831, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %2832 = torch.aten.mul.Tensor %114, %2831 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2832, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3252 = torch.constant.int 5 - %2833 = torch.prims.convert_element_type %2832, %int5_3252 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2833, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3253 = torch.constant.int -2 - %int-1_3254 = torch.constant.int -1 - %2834 = torch.aten.transpose.int %115, %int-2_3253, %int-1_3254 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3255 = torch.constant.int 4 - %2835 = torch.aten.mul.int %int4_3255, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3256 = torch.constant.int 4096 - %2836 = torch.prim.ListConstruct %2835, %int4096_3256 : (!torch.int, !torch.int) -> !torch.list - %2837 = torch.aten.view %2833, %2836 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2837, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2838 = torch.aten.mm %2837, %2834 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> 
- torch.bind_symbolic_shape %2838, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_3257 = torch.constant.int 4 - %int14336_3258 = torch.constant.int 14336 - %2839 = torch.prim.ListConstruct %int4_3257, %306, %int14336_3258 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2840 = torch.aten.view %2838, %2839 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2840, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %2841 = torch.aten.silu %2840 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2841, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_3259 = torch.constant.int -2 - %int-1_3260 = torch.constant.int -1 - %2842 = torch.aten.transpose.int %116, %int-2_3259, %int-1_3260 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3261 = torch.constant.int 4 - %2843 = torch.aten.mul.int %int4_3261, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3262 = torch.constant.int 4096 - %2844 = torch.prim.ListConstruct %2843, %int4096_3262 : (!torch.int, !torch.int) -> !torch.list - %2845 = torch.aten.view %2833, %2844 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2845, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2846 = torch.aten.mm %2845, %2842 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2846, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_3263 = torch.constant.int 4 - %int14336_3264 = torch.constant.int 14336 - %2847 = torch.prim.ListConstruct %int4_3263, %306, %int14336_3264 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2848 = torch.aten.view %2846, %2847 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2848, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %2849 = torch.aten.mul.Tensor %2841, %2848 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %2849, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_3265 = torch.constant.int -2 - %int-1_3266 = torch.constant.int -1 - %2850 = torch.aten.transpose.int %117, %int-2_3265, %int-1_3266 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_3267 = torch.constant.int 1 - %2851 = torch.aten.size.int %2840, %int1_3267 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_3268 = torch.constant.int 4 - %2852 = torch.aten.mul.int %int4_3268, %2851 : !torch.int, !torch.int -> !torch.int - %int14336_3269 = torch.constant.int 14336 - %2853 = torch.prim.ListConstruct %2852, %int14336_3269 : (!torch.int, !torch.int) -> !torch.list - %2854 = torch.aten.view %2849, %2853 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %2854, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %2855 = torch.aten.mm %2854, %2850 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2855, [%292], affine_map<()[s0] -> (s0 * 128, 
4096)> : !torch.vtensor<[?,4096],f16> - %int4_3270 = torch.constant.int 4 - %int4096_3271 = torch.constant.int 4096 - %2856 = torch.prim.ListConstruct %int4_3270, %2851, %int4096_3271 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2857 = torch.aten.view %2855, %2856 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2857, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_3272 = torch.constant.int 1 - %2858 = torch.aten.add.Tensor %2823, %2857, %int1_3272 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2858, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_3273 = torch.constant.int 6 - %2859 = torch.prims.convert_element_type %2858, %int6_3273 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2859, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_3274 = torch.constant.int 2 - %2860 = torch.aten.pow.Tensor_Scalar %2859, %int2_3274 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2860, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_3275 = torch.constant.int -1 - %2861 = torch.prim.ListConstruct %int-1_3275 : (!torch.int) -> !torch.list - %true_3276 = torch.constant.bool true - %none_3277 = torch.constant.none - %2862 = torch.aten.mean.dim %2860, %2861, %true_3276, %none_3277 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2862, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_3278 = torch.constant.float 9.9999997473787516E-6 + %2835 = torch.aten.slice.Tensor %2834, %int3_3240, %int0_3241, %int9223372036854775807_3242, %int1_3243 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2835, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_3244 = torch.constant.int 4 + %int1_3245 = torch.constant.int 1 + %int1_3246 = torch.constant.int 1 + %int1_3247 = torch.constant.int 1 + %2836 = torch.prim.ListConstruct %int4_3244, %int1_3245, %int1_3246, %int1_3247 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2837 = torch.aten.repeat %2835, %2836 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2837, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %2838 = torch.aten.mul.Tensor %2778, %2831 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2838, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_3248 = torch.constant.int 3 + %int0_3249 = torch.constant.int 0 + %int64_3250 = torch.constant.int 64 + %int1_3251 = torch.constant.int 1 + %2839 = torch.aten.slice.Tensor %2778, %int3_3248, %int0_3249, %int64_3250, %int1_3251 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %2839, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_3252 = 
torch.constant.int 3 + %int64_3253 = torch.constant.int 64 + %int9223372036854775807_3254 = torch.constant.int 9223372036854775807 + %int1_3255 = torch.constant.int 1 + %2840 = torch.aten.slice.Tensor %2778, %int3_3252, %int64_3253, %int9223372036854775807_3254, %int1_3255 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %2840, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %2841 = torch.aten.neg %2840 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %2841, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %2842 = torch.prim.ListConstruct %2841, %2839 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_3256 = torch.constant.int -1 + %2843 = torch.aten.cat %2842, %int-1_3256 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2843, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %2844 = torch.aten.mul.Tensor %2843, %2837 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2844, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_3257 = torch.constant.int 1 + %2845 = torch.aten.add.Tensor %2838, %2844, %int1_3257 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2845, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_3258 = torch.constant.int 131072 + %none_3259 = torch.constant.none + %none_3260 = torch.constant.none + %cpu_3261 = torch.constant.device "cpu" + %false_3262 = torch.constant.bool false + %2846 = torch.aten.arange %int131072_3258, %none_3259, %none_3260, %cpu_3261, %false_3262 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_3263 = torch.constant.int 0 + %int128_3264 = torch.constant.int 128 + %int2_3265 = torch.constant.int 2 + %int4_3266 = torch.constant.int 4 + %none_3267 = torch.constant.none + %cpu_3268 = torch.constant.device "cpu" + %false_3269 = torch.constant.bool false + %2847 = torch.aten.arange.start_step %int0_3263, %int128_3264, %int2_3265, %int4_3266, %none_3267, %cpu_3268, %false_3269 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_3270 = torch.constant.int 6 + %2848 = torch.prims.convert_element_type %2847, %int6_3270 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_3271 = torch.constant.int 128 + %2849 = torch.aten.div.Scalar %2848, %int128_3271 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_3272 = torch.constant.float 5.000000e+05 + %2850 = torch.aten.pow.Scalar %float5.000000e05_3272, %2849 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2851 = torch.aten.reciprocal %2850 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_3273 = torch.constant.float 1.000000e+00 + %2852 = torch.aten.mul.Scalar %2851, %float1.000000e00_3273 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %2853 = torch.aten.reciprocal %2852 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_3274 
= torch.constant.float 6.2831853071795862 + %2854 = torch.aten.mul.Scalar %2853, %float6.283190e00_3274 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_3275 = torch.constant.float 8.192000e+03 + %2855 = torch.aten.gt.Scalar %2854, %float8.192000e03_3275 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_3276 = torch.constant.int 8 + %2856 = torch.aten.div.Scalar %2852, %int8_3276 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2857 = torch.aten.where.self %2855, %2856, %2852 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2858 = torch.aten.reciprocal %2854 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_3277 = torch.constant.int 8192 + %2859 = torch.aten.mul.Scalar %2858, %int8192_3277 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_3278 = torch.constant.int 1 %int1_3279 = torch.constant.int 1 - %2863 = torch.aten.add.Scalar %2862, %float9.999990e-06_3278, %int1_3279 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2863, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2864 = torch.aten.rsqrt %2863 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %2864, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %2865 = torch.aten.mul.Tensor %2859, %2864 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2865, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3280 = torch.constant.int 5 - %2866 = torch.prims.convert_element_type %2865, %int5_3280 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2866, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %2867 = torch.aten.mul.Tensor %118, %2866 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %2867, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3281 = torch.constant.int 5 - %2868 = torch.prims.convert_element_type %2867, %int5_3281 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2868, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3282 = torch.constant.int -2 - %int-1_3283 = torch.constant.int -1 - %2869 = torch.aten.transpose.int %119, %int-2_3282, %int-1_3283 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3284 = torch.constant.int 4 - %2870 = torch.aten.mul.int %int4_3284, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3285 = torch.constant.int 4096 - %2871 = torch.prim.ListConstruct %2870, %int4096_3285 : (!torch.int, !torch.int) -> !torch.list - %2872 = torch.aten.view %2868, %2871 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2872, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2873 = torch.aten.mm %2872, %2869 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2873, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - 
%int4_3286 = torch.constant.int 4 - %int4096_3287 = torch.constant.int 4096 - %2874 = torch.prim.ListConstruct %int4_3286, %306, %int4096_3287 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2875 = torch.aten.view %2873, %2874 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %2875, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3288 = torch.constant.int -2 - %int-1_3289 = torch.constant.int -1 - %2876 = torch.aten.transpose.int %120, %int-2_3288, %int-1_3289 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3290 = torch.constant.int 4 - %2877 = torch.aten.mul.int %int4_3290, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3291 = torch.constant.int 4096 - %2878 = torch.prim.ListConstruct %2877, %int4096_3291 : (!torch.int, !torch.int) -> !torch.list - %2879 = torch.aten.view %2868, %2878 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2879, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2880 = torch.aten.mm %2879, %2876 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %2880, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_3292 = torch.constant.int 4 - %int1024_3293 = torch.constant.int 1024 - %2881 = torch.prim.ListConstruct %int4_3292, %306, %int1024_3293 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2882 = torch.aten.view %2880, %2881 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %2882, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_3294 = torch.constant.int -2 - %int-1_3295 = torch.constant.int -1 - %2883 = torch.aten.transpose.int %121, %int-2_3294, %int-1_3295 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3296 = torch.constant.int 4 - %2884 = torch.aten.mul.int %int4_3296, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3297 = torch.constant.int 4096 - %2885 = torch.prim.ListConstruct %2884, %int4096_3297 : (!torch.int, !torch.int) -> !torch.list - %2886 = torch.aten.view %2868, %2885 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %2886, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %2887 = torch.aten.mm %2886, %2883 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %2887, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_3298 = torch.constant.int 4 - %int1024_3299 = torch.constant.int 1024 - %2888 = torch.prim.ListConstruct %int4_3298, %306, %int1024_3299 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2889 = torch.aten.view %2887, %2888 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %2889, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_3300 = torch.constant.int 4 - %int32_3301 = torch.constant.int 32 - %int128_3302 = torch.constant.int 128 - %2890 = torch.prim.ListConstruct %int4_3300, %306, %int32_3301, %int128_3302 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2891 = torch.aten.view %2875, %2890 : 
!torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2891, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_3303 = torch.constant.int 4 - %int8_3304 = torch.constant.int 8 - %int128_3305 = torch.constant.int 128 - %2892 = torch.prim.ListConstruct %int4_3303, %306, %int8_3304, %int128_3305 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2893 = torch.aten.view %2882, %2892 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2893, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_3306 = torch.constant.int 4 - %int8_3307 = torch.constant.int 8 - %int128_3308 = torch.constant.int 128 - %2894 = torch.prim.ListConstruct %int4_3306, %306, %int8_3307, %int128_3308 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2895 = torch.aten.view %2889, %2894 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2895, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_3309 = torch.constant.int 131072 - %none_3310 = torch.constant.none - %none_3311 = torch.constant.none - %cpu_3312 = torch.constant.device "cpu" - %false_3313 = torch.constant.bool false - %2896 = torch.aten.arange %int131072_3309, %none_3310, %none_3311, %cpu_3312, %false_3313 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_3314 = torch.constant.int 0 - %int128_3315 = torch.constant.int 128 - %none_3316 = torch.constant.none - %none_3317 = torch.constant.none - %cpu_3318 = torch.constant.device "cpu" - %false_3319 = torch.constant.bool false - %2897 = torch.aten.arange.start %int0_3314, %int128_3315, %none_3316, %none_3317, %cpu_3318, %false_3319 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_3320 = torch.constant.int 2 - %2898 = torch.aten.floor_divide.Scalar %2897, %int2_3320 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_3321 = torch.constant.int 6 - %2899 = torch.prims.convert_element_type %2898, %int6_3321 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_3322 = torch.constant.int 128 - %2900 = torch.aten.div.Scalar %2899, %int128_3322 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_3323 = torch.constant.float 2.000000e+00 - %2901 = torch.aten.mul.Scalar %2900, %float2.000000e00_3323 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_3324 = torch.constant.float 5.000000e+05 - %2902 = torch.aten.pow.Scalar %float5.000000e05_3324, %2901 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %2903 = torch.aten.reciprocal %2902 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_3325 = torch.constant.float 1.000000e+00 - %2904 = torch.aten.mul.Scalar %2903, %float1.000000e00_3325 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_3326 = torch.constant.int 1 - %2905 = torch.aten.unsqueeze %2896, %int1_3326 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_3327 = torch.constant.int 0 - %2906 = torch.aten.unsqueeze %2904, %int0_3327 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %2907 = torch.aten.mul.Tensor 
%2905, %2906 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_3328 = torch.constant.int 1 - %2908 = torch.aten.size.int %2875, %int1_3328 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_3329 = torch.constant.int 0 - %2909 = torch.aten.add.int %int0_3329, %2908 : !torch.int, !torch.int -> !torch.int + %2860 = torch.aten.sub.Scalar %2859, %int1_3278, %int1_3279 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_3280 = torch.constant.int 3 + %2861 = torch.aten.div.Scalar %2860, %int3_3280 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_3281 = torch.constant.int 1 + %int1_3282 = torch.constant.int 1 + %2862 = torch.aten.rsub.Scalar %2861, %int1_3281, %int1_3282 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %2863 = torch.aten.mul.Tensor %2862, %2857 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_3283 = torch.constant.int 8 + %2864 = torch.aten.div.Scalar %2863, %int8_3283 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %2865 = torch.aten.mul.Tensor %2861, %2857 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_3284 = torch.constant.int 1 + %2866 = torch.aten.add.Tensor %2864, %2865, %int1_3284 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_3285 = torch.constant.float 2.048000e+03 + %2867 = torch.aten.lt.Scalar %2854, %float2.048000e03_3285 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2868 = torch.aten.bitwise_not %2867 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_3286 = torch.constant.float 8.192000e+03 + %2869 = torch.aten.gt.Scalar %2854, %float8.192000e03_3286 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %2870 = torch.aten.bitwise_not %2869 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2871 = torch.aten.mul.Tensor %2868, %2870 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %2872 = torch.aten.where.self %2871, %2866, %2857 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %2873 = torch.prim.ListConstruct %2872, %2872 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_3287 = torch.constant.int -1 + %2874 = torch.aten.cat %2873, %int-1_3287 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_3288 = torch.constant.int 6 + %2875 = torch.prims.convert_element_type %2874, %int6_3288 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_3289 = torch.constant.int 1 + %2876 = torch.aten.unsqueeze %2846, %int1_3289 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_3290 = torch.constant.int 6 + %2877 = torch.prims.convert_element_type %2876, %int6_3290 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_3291 = torch.constant.int 0 + %2878 = torch.aten.unsqueeze %2875, %int0_3291 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_3292 = torch.constant.int 6 + %2879 = torch.prims.convert_element_type %2878, %int6_3292 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %2880 = torch.aten.mul.Tensor %2877, %2879 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + 
%2881 = torch.aten.cos %2880 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_3293 = torch.constant.int 5 + %2882 = torch.prims.convert_element_type %2881, %int5_3293 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %2883 = torch.aten.sin %2880 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_3294 = torch.constant.int 5 + %2884 = torch.prims.convert_element_type %2883, %int5_3294 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_3295 = torch.constant.int 0 + %int0_3296 = torch.constant.int 0 + %int1_3297 = torch.constant.int 1 + %2885 = torch.aten.slice.Tensor %2882, %int0_3295, %int0_3296, %298, %int1_3297 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2885, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_3298 = torch.constant.int 1 + %int0_3299 = torch.constant.int 0 + %int9223372036854775807_3300 = torch.constant.int 9223372036854775807 + %int1_3301 = torch.constant.int 1 + %2886 = torch.aten.slice.Tensor %2885, %int1_3298, %int0_3299, %int9223372036854775807_3300, %int1_3301 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2886, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_3302 = torch.constant.int 0 + %int0_3303 = torch.constant.int 0 + %int1_3304 = torch.constant.int 1 + %2887 = torch.aten.slice.Tensor %2884, %int0_3302, %int0_3303, %298, %int1_3304 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2887, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_3305 = torch.constant.int 1 + %int0_3306 = torch.constant.int 0 + %int9223372036854775807_3307 = torch.constant.int 9223372036854775807 + %int1_3308 = torch.constant.int 1 + %2888 = torch.aten.slice.Tensor %2887, %int1_3305, %int0_3306, %int9223372036854775807_3307, %int1_3308 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2888, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_3309 = torch.constant.int 0 + %2889 = torch.aten.unsqueeze %2886, %int0_3309 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2889, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_3310 = torch.constant.int 1 + %int0_3311 = torch.constant.int 0 + %int9223372036854775807_3312 = torch.constant.int 9223372036854775807 + %int1_3313 = torch.constant.int 1 + %2890 = torch.aten.slice.Tensor %2889, %int1_3310, %int0_3311, %int9223372036854775807_3312, %int1_3313 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2890, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_3314 = torch.constant.int 2 + %2891 = torch.aten.unsqueeze %2890, %int2_3314 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2891, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_3315 = torch.constant.int 3 + %int0_3316 = torch.constant.int 0 + 
%int9223372036854775807_3317 = torch.constant.int 9223372036854775807 + %int1_3318 = torch.constant.int 1 + %2892 = torch.aten.slice.Tensor %2891, %int3_3315, %int0_3316, %int9223372036854775807_3317, %int1_3318 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2892, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_3319 = torch.constant.int 4 + %int1_3320 = torch.constant.int 1 + %int1_3321 = torch.constant.int 1 + %int1_3322 = torch.constant.int 1 + %2893 = torch.prim.ListConstruct %int4_3319, %int1_3320, %int1_3321, %int1_3322 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2894 = torch.aten.repeat %2892, %2893 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2894, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_3323 = torch.constant.int 0 + %2895 = torch.aten.unsqueeze %2888, %int0_3323 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2895, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_3324 = torch.constant.int 1 + %int0_3325 = torch.constant.int 0 + %int9223372036854775807_3326 = torch.constant.int 9223372036854775807 + %int1_3327 = torch.constant.int 1 + %2896 = torch.aten.slice.Tensor %2895, %int1_3324, %int0_3325, %int9223372036854775807_3326, %int1_3327 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %2896, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_3328 = torch.constant.int 2 + %2897 = torch.aten.unsqueeze %2896, %int2_3328 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2897, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_3329 = torch.constant.int 3 %int0_3330 = torch.constant.int 0 - %int0_3331 = torch.constant.int 0 + %int9223372036854775807_3331 = torch.constant.int 9223372036854775807 %int1_3332 = torch.constant.int 1 - %2910 = torch.aten.slice.Tensor %2907, %int0_3330, %int0_3331, %2909, %int1_3332 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2910, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3333 = torch.constant.int 1 - %int0_3334 = torch.constant.int 0 - %int9223372036854775807_3335 = torch.constant.int 9223372036854775807 + %2898 = torch.aten.slice.Tensor %2897, %int3_3329, %int0_3330, %int9223372036854775807_3331, %int1_3332 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %2898, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_3333 = torch.constant.int 4 + %int1_3334 = torch.constant.int 1 + %int1_3335 = torch.constant.int 1 %int1_3336 = torch.constant.int 1 - %2911 = torch.aten.slice.Tensor %2910, %int1_3333, %int0_3334, %int9223372036854775807_3335, %int1_3336 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2911, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3337 = 
torch.constant.int 1 + %2899 = torch.prim.ListConstruct %int4_3333, %int1_3334, %int1_3335, %int1_3336 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2900 = torch.aten.repeat %2898, %2899 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %2900, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %2901 = torch.aten.mul.Tensor %2780, %2894 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2901, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_3337 = torch.constant.int 3 %int0_3338 = torch.constant.int 0 - %int9223372036854775807_3339 = torch.constant.int 9223372036854775807 + %int64_3339 = torch.constant.int 64 %int1_3340 = torch.constant.int 1 - %2912 = torch.aten.slice.Tensor %2911, %int1_3337, %int0_3338, %int9223372036854775807_3339, %int1_3340 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2912, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_3341 = torch.constant.int 0 - %2913 = torch.aten.unsqueeze %2912, %int0_3341 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2913, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_3342 = torch.constant.int 1 - %int0_3343 = torch.constant.int 0 - %int9223372036854775807_3344 = torch.constant.int 9223372036854775807 - %int1_3345 = torch.constant.int 1 - %2914 = torch.aten.slice.Tensor %2913, %int1_3342, %int0_3343, %int9223372036854775807_3344, %int1_3345 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2914, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_3346 = torch.constant.int 2 - %int0_3347 = torch.constant.int 0 - %int9223372036854775807_3348 = torch.constant.int 9223372036854775807 - %int1_3349 = torch.constant.int 1 - %2915 = torch.aten.slice.Tensor %2914, %int2_3346, %int0_3347, %int9223372036854775807_3348, %int1_3349 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2915, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_3350 = torch.constant.int 4 + %2902 = torch.aten.slice.Tensor %2780, %int3_3337, %int0_3338, %int64_3339, %int1_3340 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %2902, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_3341 = torch.constant.int 3 + %int64_3342 = torch.constant.int 64 + %int9223372036854775807_3343 = torch.constant.int 9223372036854775807 + %int1_3344 = torch.constant.int 1 + %2903 = torch.aten.slice.Tensor %2780, %int3_3341, %int64_3342, %int9223372036854775807_3343, %int1_3344 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %2903, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %2904 = torch.aten.neg %2903 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %2904, [%294], 
affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %2905 = torch.prim.ListConstruct %2904, %2902 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_3345 = torch.constant.int -1 + %2906 = torch.aten.cat %2905, %int-1_3345 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2906, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %2907 = torch.aten.mul.Tensor %2906, %2900 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2907, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_3346 = torch.constant.int 1 + %2908 = torch.aten.add.Tensor %2901, %2907, %int1_3346 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2908, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_3347 = torch.constant.int 32 + %2909 = torch.aten.mul.Scalar %arg2, %int32_3347 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2909, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int9 = torch.constant.int 9 + %int1_3348 = torch.constant.int 1 + %2910 = torch.aten.add.Scalar %2909, %int9, %int1_3348 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2910, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_3349 = torch.constant.int 2 + %2911 = torch.aten.mul.Scalar %2910, %int2_3349 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2911, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_3350 = torch.constant.int 0 %int1_3351 = torch.constant.int 1 - %int1_3352 = torch.constant.int 1 - %2916 = torch.prim.ListConstruct %int4_3350, %int1_3351, %int1_3352 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2917 = torch.aten.repeat %2915, %2916 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %2917, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_3353 = torch.constant.int 6 - %2918 = torch.prims.convert_element_type %2891, %int6_3353 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %2918, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %2919 = torch_c.to_builtin_tensor %2918 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %2920 = torch_c.to_builtin_tensor %2917 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %2921 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%2919, %2920) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %2922 = torch_c.from_builtin_tensor %2921 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %2922, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_3354 = torch.constant.int 5 - %2923 = torch.prims.convert_element_type %2922, %int5_3354 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2923, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - 
%int131072_3355 = torch.constant.int 131072 - %none_3356 = torch.constant.none - %none_3357 = torch.constant.none - %cpu_3358 = torch.constant.device "cpu" - %false_3359 = torch.constant.bool false - %2924 = torch.aten.arange %int131072_3355, %none_3356, %none_3357, %cpu_3358, %false_3359 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_3360 = torch.constant.int 0 - %int128_3361 = torch.constant.int 128 - %none_3362 = torch.constant.none - %none_3363 = torch.constant.none - %cpu_3364 = torch.constant.device "cpu" - %false_3365 = torch.constant.bool false - %2925 = torch.aten.arange.start %int0_3360, %int128_3361, %none_3362, %none_3363, %cpu_3364, %false_3365 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_3366 = torch.constant.int 2 - %2926 = torch.aten.floor_divide.Scalar %2925, %int2_3366 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_3367 = torch.constant.int 6 - %2927 = torch.prims.convert_element_type %2926, %int6_3367 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_3368 = torch.constant.int 128 - %2928 = torch.aten.div.Scalar %2927, %int128_3368 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_3369 = torch.constant.float 2.000000e+00 - %2929 = torch.aten.mul.Scalar %2928, %float2.000000e00_3369 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_3370 = torch.constant.float 5.000000e+05 - %2930 = torch.aten.pow.Scalar %float5.000000e05_3370, %2929 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %2931 = torch.aten.reciprocal %2930 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_3371 = torch.constant.float 1.000000e+00 - %2932 = torch.aten.mul.Scalar %2931, %float1.000000e00_3371 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_3372 = torch.constant.int 1 - %2933 = torch.aten.unsqueeze %2924, %int1_3372 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_3373 = torch.constant.int 0 - %2934 = torch.aten.unsqueeze %2932, %int0_3373 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %2935 = torch.aten.mul.Tensor %2933, %2934 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_3374 = torch.constant.int 1 - %2936 = torch.aten.size.int %2882, %int1_3374 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_3375 = torch.constant.int 0 - %2937 = torch.aten.add.int %int0_3375, %2936 : !torch.int, !torch.int -> !torch.int - %int0_3376 = torch.constant.int 0 - %int0_3377 = torch.constant.int 0 - %int1_3378 = torch.constant.int 1 - %2938 = torch.aten.slice.Tensor %2935, %int0_3376, %int0_3377, %2937, %int1_3378 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2938, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3379 = torch.constant.int 1 - %int0_3380 = torch.constant.int 0 - %int9223372036854775807_3381 = torch.constant.int 9223372036854775807 - %int1_3382 = torch.constant.int 1 - %2939 = torch.aten.slice.Tensor %2938, %int1_3379, %int0_3380, %int9223372036854775807_3381, %int1_3382 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - 
torch.bind_symbolic_shape %2939, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3383 = torch.constant.int 1 - %int0_3384 = torch.constant.int 0 - %int9223372036854775807_3385 = torch.constant.int 9223372036854775807 - %int1_3386 = torch.constant.int 1 - %2940 = torch.aten.slice.Tensor %2939, %int1_3383, %int0_3384, %int9223372036854775807_3385, %int1_3386 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %2940, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_3387 = torch.constant.int 0 - %2941 = torch.aten.unsqueeze %2940, %int0_3387 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2941, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_3388 = torch.constant.int 1 - %int0_3389 = torch.constant.int 0 - %int9223372036854775807_3390 = torch.constant.int 9223372036854775807 - %int1_3391 = torch.constant.int 1 - %2942 = torch.aten.slice.Tensor %2941, %int1_3388, %int0_3389, %int9223372036854775807_3390, %int1_3391 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2942, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_3392 = torch.constant.int 2 - %int0_3393 = torch.constant.int 0 - %int9223372036854775807_3394 = torch.constant.int 9223372036854775807 - %int1_3395 = torch.constant.int 1 - %2943 = torch.aten.slice.Tensor %2942, %int2_3392, %int0_3393, %int9223372036854775807_3394, %int1_3395 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %2943, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_3396 = torch.constant.int 4 - %int1_3397 = torch.constant.int 1 + %2912 = torch.aten.add.Scalar %2911, %int0_3350, %int1_3351 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2912, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %2913 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %2914 = torch.aten.view %2912, %2913 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %2914, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_3352 = torch.constant.int 4 + %int32_3353 = torch.constant.int 32 + %int8_3354 = torch.constant.int 8 + %int128_3355 = torch.constant.int 128 + %2915 = torch.prim.ListConstruct %int4_3352, %296, %int32_3353, %int8_3354, %int128_3355 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2916 = torch.aten.view %2908, %2915 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2916, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_3356 = torch.constant.int 32 + %int8_3357 = torch.constant.int 8 + %int128_3358 = torch.constant.int 128 + %2917 = torch.prim.ListConstruct %504, %int32_3356, %int8_3357, %int128_3358 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2918 = torch.aten.view %2916, %2917 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %2918, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : 
!torch.vtensor<[?,32,8,128],f16> + %int1_3359 = torch.constant.int 1 + %int2_3360 = torch.constant.int 2 + %2919 = torch.aten.transpose.int %2918, %int1_3359, %int2_3360 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2919, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_3361 = torch.constant.int 5 + %2920 = torch.prims.convert_element_type %2919, %int5_3361 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2920, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_3362 = torch.constant.int 32 + %int2_3363 = torch.constant.int 2 + %int8_3364 = torch.constant.int 8 + %int32_3365 = torch.constant.int 32 + %int128_3366 = torch.constant.int 128 + %2921 = torch.prim.ListConstruct %297, %int32_3362, %int2_3363, %int8_3364, %int32_3365, %int128_3366 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2922 = torch.aten.view %2684, %2921 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2922, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_3367 = torch.constant.int 8 + %int32_3368 = torch.constant.int 32 + %int128_3369 = torch.constant.int 128 + %2923 = torch.prim.ListConstruct %497, %int8_3367, %int32_3368, %int128_3369 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2924 = torch.aten.view %2922, %2923 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2924, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %2925 = torch.prim.ListConstruct %2914 : (!torch.vtensor<[?],si64>) -> !torch.list + %false_3370 = torch.constant.bool false + %2926 = torch.aten.index_put %2924, %2925, %2920, %false_3370 : !torch.vtensor<[?,8,32,128],f16>, !torch.list, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2926, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_3371 = torch.constant.int 32 + %int2_3372 = torch.constant.int 2 + %int8_3373 = torch.constant.int 8 + %int32_3374 = torch.constant.int 32 + %int128_3375 = torch.constant.int 128 + %2927 = torch.prim.ListConstruct %297, %int32_3371, %int2_3372, %int8_3373, %int32_3374, %int128_3375 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2928 = torch.aten.view %2926, %2927 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2928, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_3376 = torch.constant.int 2097152 + %2929 = torch.prim.ListConstruct %297, %int2097152_3376 : (!torch.int, !torch.int) -> !torch.list + %2930 = torch.aten.view %2928, %2929 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2930, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_3377 = torch.constant.int 32 + %int2_3378 = torch.constant.int 2 + %int8_3379 = torch.constant.int 8 + %int32_3380 = torch.constant.int 32 + %int128_3381 = torch.constant.int 128 + %2931 = torch.prim.ListConstruct %297, %int32_3377,
%int2_3378, %int8_3379, %int32_3380, %int128_3381 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2932 = torch.aten.view %2930, %2931 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2932, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_3382 = torch.constant.int 8 + %int32_3383 = torch.constant.int 32 + %int128_3384 = torch.constant.int 128 + %2933 = torch.prim.ListConstruct %497, %int8_3382, %int32_3383, %int128_3384 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2934 = torch.aten.view %2932, %2933 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2934, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_3385 = torch.constant.int 32 + %2935 = torch.aten.mul.Scalar %arg2, %int32_3385 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2935, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int9_3386 = torch.constant.int 9 + %int1_3387 = torch.constant.int 1 + %2936 = torch.aten.add.Scalar %2935, %int9_3386, %int1_3387 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2936, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_3388 = torch.constant.int 2 + %2937 = torch.aten.mul.Scalar %2936, %int2_3388 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2937, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_3389 = torch.constant.int 1 + %int1_3390 = torch.constant.int 1 + %2938 = torch.aten.add.Scalar %2937, %int1_3389, %int1_3390 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %2938, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %2939 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %2940 = torch.aten.view %2938, %2939 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %2940, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_3391 = torch.constant.int 4 + %int32_3392 = torch.constant.int 32 + %int8_3393 = torch.constant.int 8 + %int128_3394 = torch.constant.int 128 + %2941 = torch.prim.ListConstruct %int4_3391, %296, %int32_3392, %int8_3393, %int128_3394 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2942 = torch.aten.view %2782, %2941 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2942, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_3395 = torch.constant.int 32 + %int8_3396 = torch.constant.int 8 + %int128_3397 = torch.constant.int 128 + %2943 = torch.prim.ListConstruct %504, %int32_3395, %int8_3396, %int128_3397 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2944 = torch.aten.view %2942, %2943 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %2944, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> %int1_3398 = torch.constant.int 1 - %2944 = torch.prim.ListConstruct %int4_3396, %int1_3397, %int1_3398 : (!torch.int, !torch.int, 
!torch.int) -> !torch.list - %2945 = torch.aten.repeat %2943, %2944 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %2945, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_3399 = torch.constant.int 6 - %2946 = torch.prims.convert_element_type %2893, %int6_3399 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %2946, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %2947 = torch_c.to_builtin_tensor %2946 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %2948 = torch_c.to_builtin_tensor %2945 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %2949 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%2947, %2948) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %2950 = torch_c.from_builtin_tensor %2949 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %2950, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> + %int2_3399 = torch.constant.int 2 + %2945 = torch.aten.transpose.int %2944, %int1_3398, %int2_3399 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2945, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> %int5_3400 = torch.constant.int 5 - %2951 = torch.prims.convert_element_type %2950, %int5_3400 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2951, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_3401 = torch.constant.int 64 - %2952 = torch.aten.mul.Scalar %arg2, %int64_3401 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2952, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int26 = torch.constant.int 26 - %int1_3402 = torch.constant.int 1 - %2953 = torch.aten.add.Scalar %2952, %int26, %int1_3402 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2953, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_3403 = torch.constant.int 4 - %int32_3404 = torch.constant.int 32 - %int8_3405 = torch.constant.int 8 + %2946 = torch.prims.convert_element_type %2945, %int5_3400 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2946, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %2947 = torch.prim.ListConstruct %2940 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_3401 = torch.constant.bool false + %2948 = torch.aten.index_put %2934, %2947, %2946, %false_3401 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %2948, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_3402 = torch.constant.int 32 + %int2_3403 = torch.constant.int 2 + %int8_3404 = torch.constant.int 8 + %int32_3405 = torch.constant.int 32 %int128_3406 = torch.constant.int 128 - %2954 = torch.prim.ListConstruct %int4_3403, %398, %int32_3404, %int8_3405, %int128_3406 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2955 = torch.aten.view %2951, %2954 : 
!torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2955, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_3407 = torch.constant.int 4 - %2956 = torch.aten.mul.int %int4_3407, %398 : !torch.int, !torch.int -> !torch.int - %int32_3408 = torch.constant.int 32 - %int8_3409 = torch.constant.int 8 - %int128_3410 = torch.constant.int 128 - %2957 = torch.prim.ListConstruct %2956, %int32_3408, %int8_3409, %int128_3410 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2958 = torch.aten.view %2955, %2957 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2958, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %2949 = torch.prim.ListConstruct %297, %int32_3402, %int2_3403, %int8_3404, %int32_3405, %int128_3406 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2950 = torch.aten.view %2948, %2949 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2950, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_3407 = torch.constant.int 2097152 + %2951 = torch.prim.ListConstruct %297, %int2097152_3407 : (!torch.int, !torch.int) -> !torch.list + %2952 = torch.aten.view %2950, %2951 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2952, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_3408 = torch.constant.int -2 + %2953 = torch.aten.unsqueeze %2908, %int-2_3408 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2953, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_3409 = torch.constant.int 4 + %int8_3410 = torch.constant.int 8 %int4_3411 = torch.constant.int 4 - %2959 = torch.aten.mul.int %int4_3411, %398 : !torch.int, !torch.int -> !torch.int - %2960 = torch.prim.ListConstruct %2959 : (!torch.int) -> !torch.list - %2961 = torch.aten.view %2953, %2960 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2961, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_3412 = torch.constant.int 32 - %int2_3413 = torch.constant.int 2 - %int32_3414 = torch.constant.int 32 - %int8_3415 = torch.constant.int 8 - %int128_3416 = torch.constant.int 128 - %2962 = torch.prim.ListConstruct %389, %int32_3412, %int2_3413, %int32_3414, %int8_3415, %int128_3416 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2963 = torch.aten.view %2795, %2962 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2963, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_3417 = torch.constant.int 32 - %2964 = torch.aten.mul.int %389, %int32_3417 : !torch.int, !torch.int -> !torch.int - %int2_3418 = torch.constant.int 2 - %2965 = torch.aten.mul.int %2964, %int2_3418 : !torch.int, !torch.int -> !torch.int - %int32_3419 = torch.constant.int 32 + %int128_3412 = torch.constant.int 128 + %2954 = torch.prim.ListConstruct %int4_3409, %298, %int8_3410, %int4_3411, %int128_3412 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) 
-> !torch.list + %false_3413 = torch.constant.bool false + %2955 = torch.aten.expand %2953, %2954, %false_3413 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2955, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_3414 = torch.constant.int 0 + %2956 = torch.aten.clone %2955, %int0_3414 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2956, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_3415 = torch.constant.int 4 + %int32_3416 = torch.constant.int 32 + %int128_3417 = torch.constant.int 128 + %2957 = torch.prim.ListConstruct %int4_3415, %298, %int32_3416, %int128_3417 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2958 = torch.aten._unsafe_view %2956, %2957 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2958, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_3418 = torch.constant.int -2 + %2959 = torch.aten.unsqueeze %2782, %int-2_3418 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2959, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_3419 = torch.constant.int 4 %int8_3420 = torch.constant.int 8 - %int128_3421 = torch.constant.int 128 - %2966 = torch.prim.ListConstruct %2965, %int32_3419, %int8_3420, %int128_3421 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2967 = torch.aten.view %2963, %2966 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2967, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %2968 = torch.prim.ListConstruct %2961 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_3422 = torch.constant.bool false - %2969 = torch.aten.index_put %2967, %2968, %2958, %false_3422 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2969, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_3423 = torch.constant.int 32 - %int2_3424 = torch.constant.int 2 - %int32_3425 = torch.constant.int 32 - %int8_3426 = torch.constant.int 8 + %int4_3421 = torch.constant.int 4 + %int128_3422 = torch.constant.int 128 + %2960 = torch.prim.ListConstruct %int4_3419, %298, %int8_3420, %int4_3421, %int128_3422 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_3423 = torch.constant.bool false + %2961 = torch.aten.expand %2959, %2960, %false_3423 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2961, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_3424 = torch.constant.int 0 + %2962 = torch.aten.clone %2961, %int0_3424 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2962, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_3425 = torch.constant.int 4 + %int32_3426 = torch.constant.int 32 %int128_3427 = torch.constant.int 128 - %2970 = torch.prim.ListConstruct %389, 
%int32_3423, %int2_3424, %int32_3425, %int8_3426, %int128_3427 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2971 = torch.aten.view %2969, %2970 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2971, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3428 = torch.constant.int 2097152 - %2972 = torch.prim.ListConstruct %389, %int2097152_3428 : (!torch.int, !torch.int) -> !torch.list - %2973 = torch.aten.view %2971, %2972 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2973, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_3429 = torch.constant.int 32 - %int2_3430 = torch.constant.int 2 - %int32_3431 = torch.constant.int 32 - %int8_3432 = torch.constant.int 8 - %int128_3433 = torch.constant.int 128 - %2974 = torch.prim.ListConstruct %389, %int32_3429, %int2_3430, %int32_3431, %int8_3432, %int128_3433 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2975 = torch.aten.view %2973, %2974 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2975, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_3434 = torch.constant.int 32 - %int8_3435 = torch.constant.int 8 - %int128_3436 = torch.constant.int 128 - %2976 = torch.prim.ListConstruct %2965, %int32_3434, %int8_3435, %int128_3436 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2977 = torch.aten.view %2975, %2976 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2977, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_3437 = torch.constant.int 4 - %int32_3438 = torch.constant.int 32 - %int8_3439 = torch.constant.int 8 - %int128_3440 = torch.constant.int 128 - %2978 = torch.prim.ListConstruct %int4_3437, %398, %int32_3438, %int8_3439, %int128_3440 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2979 = torch.aten.view %2895, %2978 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2979, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_3441 = torch.constant.int 4 - %2980 = torch.aten.mul.int %int4_3441, %398 : !torch.int, !torch.int -> !torch.int - %int32_3442 = torch.constant.int 32 - %int8_3443 = torch.constant.int 8 - %int128_3444 = torch.constant.int 128 - %2981 = torch.prim.ListConstruct %2980, %int32_3442, %int8_3443, %int128_3444 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2982 = torch.aten.view %2979, %2981 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2982, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_3445 = torch.constant.int 1 - %int1_3446 = torch.constant.int 1 - %2983 = torch.aten.add.Scalar %2953, %int1_3445, %int1_3446 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2983, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_3447 = torch.constant.int 4 - %2984 = torch.aten.mul.int %int4_3447, %398 : !torch.int, 
!torch.int -> !torch.int - %2985 = torch.prim.ListConstruct %2984 : (!torch.int) -> !torch.list - %2986 = torch.aten.view %2983, %2985 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2986, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %2987 = torch.prim.ListConstruct %2986 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_3448 = torch.constant.bool false - %2988 = torch.aten.index_put %2977, %2987, %2982, %false_3448 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %2988, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_3449 = torch.constant.int 32 - %int2_3450 = torch.constant.int 2 - %int32_3451 = torch.constant.int 32 - %int8_3452 = torch.constant.int 8 - %int128_3453 = torch.constant.int 128 - %2989 = torch.prim.ListConstruct %389, %int32_3449, %int2_3450, %int32_3451, %int8_3452, %int128_3453 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2990 = torch.aten.view %2988, %2989 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2990, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3454 = torch.constant.int 2097152 - %2991 = torch.prim.ListConstruct %389, %int2097152_3454 : (!torch.int, !torch.int) -> !torch.list - %2992 = torch.aten.view %2990, %2991 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2992, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_3455 = torch.constant.int -2 - %2993 = torch.aten.unsqueeze %2951, %int-2_3455 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2993, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_3456 = torch.constant.int 4 - %int8_3457 = torch.constant.int 8 - %int4_3458 = torch.constant.int 4 - %int128_3459 = torch.constant.int 128 - %2994 = torch.prim.ListConstruct %int4_3456, %2936, %int8_3457, %int4_3458, %int128_3459 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_3460 = torch.constant.bool false - %2995 = torch.aten.expand %2993, %2994, %false_3460 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2995, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_3461 = torch.constant.int 0 - %2996 = torch.aten.clone %2995, %int0_3461 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2996, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_3462 = torch.constant.int 4 - %int32_3463 = torch.constant.int 32 - %int128_3464 = torch.constant.int 128 - %2997 = torch.prim.ListConstruct %int4_3462, %2936, %int32_3463, %int128_3464 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2998 = torch.aten._unsafe_view %2996, %2997 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2998, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_3465 = torch.constant.int -2 
- %2999 = torch.aten.unsqueeze %2895, %int-2_3465 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2999, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_3466 = torch.constant.int 1 - %3000 = torch.aten.size.int %2889, %int1_3466 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int + %2963 = torch.prim.ListConstruct %int4_3425, %298, %int32_3426, %int128_3427 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2964 = torch.aten._unsafe_view %2962, %2963 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2964, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_3428 = torch.constant.int 1 + %int2_3429 = torch.constant.int 2 + %2965 = torch.aten.transpose.int %2845, %int1_3428, %int2_3429 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2965, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_3430 = torch.constant.int 1 + %int2_3431 = torch.constant.int 2 + %2966 = torch.aten.transpose.int %2958, %int1_3430, %int2_3431 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2966, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_3432 = torch.constant.int 1 + %int2_3433 = torch.constant.int 2 + %2967 = torch.aten.transpose.int %2964, %int1_3432, %int2_3433 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2967, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_3434 = torch.constant.float 0.000000e+00 + %false_3435 = torch.constant.bool false + %none_3436 = torch.constant.none + %2968:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2965, %2966, %2967, %float0.000000e00_3434, %false_3435, %327, %none_3436) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %2968#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_3437 = torch.constant.int 1 + %int2_3438 = torch.constant.int 2 + %2969 = torch.aten.transpose.int %2968#0, %int1_3437, %int2_3438 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2969, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_3439 = torch.constant.int 4 + %int4096_3440 = torch.constant.int 4096 + %2970 = torch.prim.ListConstruct %int4_3439, %298, %int4096_3440 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2971 = torch.aten.view %2969, %2970 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2971, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_3441 = torch.constant.int -2 + %int-1_3442 = torch.constant.int -1 + %2972 = torch.aten.transpose.int %87, %int-2_3441, %int-1_3442 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> 
+ %int5_3443 = torch.constant.int 5 + %2973 = torch.prims.convert_element_type %2972, %int5_3443 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_3444 = torch.constant.int 4096 + %2974 = torch.prim.ListConstruct %342, %int4096_3444 : (!torch.int, !torch.int) -> !torch.list + %2975 = torch.aten.view %2971, %2974 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2975, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2976 = torch.aten.mm %2975, %2973 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2976, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_3445 = torch.constant.int 4 + %int4096_3446 = torch.constant.int 4096 + %2977 = torch.prim.ListConstruct %int4_3445, %298, %int4096_3446 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2978 = torch.aten.view %2976, %2977 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2978, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_3447 = torch.constant.int 1 + %2979 = torch.aten.add.Tensor %2745, %2978, %int1_3447 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2979, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_3448 = torch.constant.int 6 + %2980 = torch.prims.convert_element_type %2979, %int6_3448 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2980, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_3449 = torch.constant.int 2 + %2981 = torch.aten.pow.Tensor_Scalar %2980, %int2_3449 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2981, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_3450 = torch.constant.int -1 + %2982 = torch.prim.ListConstruct %int-1_3450 : (!torch.int) -> !torch.list + %true_3451 = torch.constant.bool true + %none_3452 = torch.constant.none + %2983 = torch.aten.mean.dim %2981, %2982, %true_3451, %none_3452 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2983, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_3453 = torch.constant.float 9.9999997473787516E-6 + %int1_3454 = torch.constant.int 1 + %2984 = torch.aten.add.Scalar %2983, %float9.999990e-06_3453, %int1_3454 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2984, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2985 = torch.aten.rsqrt %2984 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %2985, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %2986 = torch.aten.mul.Tensor %2980, %2985 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2986, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_3455 = torch.constant.int 5 + %2987 = torch.prims.convert_element_type %2986, %int5_3455 : 
!torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2987, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %2988 = torch.aten.mul.Tensor %88, %2987 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %2988, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_3456 = torch.constant.int 5 + %2989 = torch.prims.convert_element_type %2988, %int5_3456 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %2989, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_3457 = torch.constant.int -2 + %int-1_3458 = torch.constant.int -1 + %2990 = torch.aten.transpose.int %89, %int-2_3457, %int-1_3458 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_3459 = torch.constant.int 5 + %2991 = torch.prims.convert_element_type %2990, %int5_3459 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_3460 = torch.constant.int 4096 + %2992 = torch.prim.ListConstruct %342, %int4096_3460 : (!torch.int, !torch.int) -> !torch.list + %2993 = torch.aten.view %2989, %2992 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %2993, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %2994 = torch.aten.mm %2993, %2991 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %2994, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_3461 = torch.constant.int 4 + %int14336_3462 = torch.constant.int 14336 + %2995 = torch.prim.ListConstruct %int4_3461, %298, %int14336_3462 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2996 = torch.aten.view %2994, %2995 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2996, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %2997 = torch.aten.silu %2996 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %2997, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_3463 = torch.constant.int -2 + %int-1_3464 = torch.constant.int -1 + %2998 = torch.aten.transpose.int %90, %int-2_3463, %int-1_3464 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_3465 = torch.constant.int 5 + %2999 = torch.prims.convert_element_type %2998, %int5_3465 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_3466 = torch.constant.int 4096 + %3000 = torch.prim.ListConstruct %342, %int4096_3466 : (!torch.int, !torch.int) -> !torch.list + %3001 = torch.aten.view %2989, %3000 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3001, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3002 = torch.aten.mm %3001, %2999 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %3002, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> %int4_3467 = torch.constant.int 4 - %int8_3468 = torch.constant.int 8 - %int4_3469 
= torch.constant.int 4 - %int128_3470 = torch.constant.int 128 - %3001 = torch.prim.ListConstruct %int4_3467, %3000, %int8_3468, %int4_3469, %int128_3470 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_3471 = torch.constant.bool false - %3002 = torch.aten.expand %2999, %3001, %false_3471 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3002, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_3472 = torch.constant.int 0 - %3003 = torch.aten.clone %3002, %int0_3472 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3003, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int14336_3468 = torch.constant.int 14336 + %3003 = torch.prim.ListConstruct %int4_3467, %298, %int14336_3468 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3004 = torch.aten.view %3002, %3003 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %3004, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %3005 = torch.aten.mul.Tensor %2997, %3004 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %3005, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_3469 = torch.constant.int -2 + %int-1_3470 = torch.constant.int -1 + %3006 = torch.aten.transpose.int %91, %int-2_3469, %int-1_3470 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_3471 = torch.constant.int 5 + %3007 = torch.prims.convert_element_type %3006, %int5_3471 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_3472 = torch.constant.int 14336 + %3008 = torch.prim.ListConstruct %342, %int14336_3472 : (!torch.int, !torch.int) -> !torch.list + %3009 = torch.aten.view %3005, %3008 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %3009, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %3010 = torch.aten.mm %3009, %3007 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3010, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> %int4_3473 = torch.constant.int 4 - %int32_3474 = torch.constant.int 32 - %int128_3475 = torch.constant.int 128 - %3004 = torch.prim.ListConstruct %int4_3473, %3000, %int32_3474, %int128_3475 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3005 = torch.aten._unsafe_view %3003, %3004 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3005, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_3476 = torch.constant.int 1 + %int4096_3474 = torch.constant.int 4096 + %3011 = torch.prim.ListConstruct %int4_3473, %298, %int4096_3474 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3012 = torch.aten.view %3010, %3011 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3012, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_3475 = torch.constant.int 1 + %3013 = 
torch.aten.add.Tensor %2979, %3012, %int1_3475 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3013, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_3476 = torch.constant.int 6 + %3014 = torch.prims.convert_element_type %3013, %int6_3476 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3014, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> %int2_3477 = torch.constant.int 2 - %3006 = torch.aten.transpose.int %2923, %int1_3476, %int2_3477 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3006, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_3478 = torch.constant.int 1 - %int2_3479 = torch.constant.int 2 - %3007 = torch.aten.transpose.int %2998, %int1_3478, %int2_3479 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3007, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_3480 = torch.constant.int 1 - %int2_3481 = torch.constant.int 2 - %3008 = torch.aten.transpose.int %3005, %int1_3480, %int2_3481 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3008, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_3482 = torch.constant.float 0.000000e+00 - %true_3483 = torch.constant.bool true - %none_3484 = torch.constant.none - %none_3485 = torch.constant.none - %3009:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3006, %3007, %3008, %float0.000000e00_3482, %true_3483, %none_3484, %none_3485) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %3009#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_3486 = torch.constant.int 1 - %int2_3487 = torch.constant.int 2 - %3010 = torch.aten.transpose.int %3009#0, %int1_3486, %int2_3487 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3010, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_3488 = torch.constant.int 4 - %int4096_3489 = torch.constant.int 4096 - %3011 = torch.prim.ListConstruct %int4_3488, %2908, %int4096_3489 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3012 = torch.aten.view %3010, %3011 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3012, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3490 = torch.constant.int -2 - %int-1_3491 = torch.constant.int -1 - %3013 = torch.aten.transpose.int %122, %int-2_3490, %int-1_3491 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3492 = torch.constant.int 4 - %3014 = torch.aten.mul.int %int4_3492, %2908 : !torch.int, !torch.int -> !torch.int - %int4096_3493 = torch.constant.int 4096 - %3015 = torch.prim.ListConstruct %3014, %int4096_3493 : (!torch.int, !torch.int) 
-> !torch.list - %3016 = torch.aten.view %3012, %3015 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3016, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3017 = torch.aten.mm %3016, %3013 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3017, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_3494 = torch.constant.int 4 - %int4096_3495 = torch.constant.int 4096 - %3018 = torch.prim.ListConstruct %int4_3494, %2908, %int4096_3495 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3019 = torch.aten.view %3017, %3018 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3019, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_3496 = torch.constant.int 1 - %3020 = torch.aten.add.Tensor %2858, %3019, %int1_3496 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3020, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_3497 = torch.constant.int 6 - %3021 = torch.prims.convert_element_type %3020, %int6_3497 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3021, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_3498 = torch.constant.int 2 - %3022 = torch.aten.pow.Tensor_Scalar %3021, %int2_3498 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3022, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_3499 = torch.constant.int -1 - %3023 = torch.prim.ListConstruct %int-1_3499 : (!torch.int) -> !torch.list - %true_3500 = torch.constant.bool true - %none_3501 = torch.constant.none - %3024 = torch.aten.mean.dim %3022, %3023, %true_3500, %none_3501 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3024, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_3502 = torch.constant.float 9.9999997473787516E-6 - %int1_3503 = torch.constant.int 1 - %3025 = torch.aten.add.Scalar %3024, %float9.999990e-06_3502, %int1_3503 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3025, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3026 = torch.aten.rsqrt %3025 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3026, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3027 = torch.aten.mul.Tensor %3021, %3026 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3027, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3504 = torch.constant.int 5 - %3028 = torch.prims.convert_element_type %3027, %int5_3504 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3028, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %3029 = torch.aten.mul.Tensor %123, %3028 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> 
!torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3029, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3505 = torch.constant.int 5 - %3030 = torch.prims.convert_element_type %3029, %int5_3505 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3030, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3506 = torch.constant.int -2 - %int-1_3507 = torch.constant.int -1 - %3031 = torch.aten.transpose.int %124, %int-2_3506, %int-1_3507 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3508 = torch.constant.int 4 - %3032 = torch.aten.mul.int %int4_3508, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3509 = torch.constant.int 4096 - %3033 = torch.prim.ListConstruct %3032, %int4096_3509 : (!torch.int, !torch.int) -> !torch.list - %3034 = torch.aten.view %3030, %3033 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3034, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3035 = torch.aten.mm %3034, %3031 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %3035, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_3510 = torch.constant.int 4 - %int14336_3511 = torch.constant.int 14336 - %3036 = torch.prim.ListConstruct %int4_3510, %306, %int14336_3511 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3037 = torch.aten.view %3035, %3036 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3037, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %3038 = torch.aten.silu %3037 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3038, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_3512 = torch.constant.int -2 - %int-1_3513 = torch.constant.int -1 - %3039 = torch.aten.transpose.int %125, %int-2_3512, %int-1_3513 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3514 = torch.constant.int 4 - %3040 = torch.aten.mul.int %int4_3514, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3515 = torch.constant.int 4096 - %3041 = torch.prim.ListConstruct %3040, %int4096_3515 : (!torch.int, !torch.int) -> !torch.list - %3042 = torch.aten.view %3030, %3041 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3042, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3043 = torch.aten.mm %3042, %3039 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %3043, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_3516 = torch.constant.int 4 - %int14336_3517 = torch.constant.int 14336 - %3044 = torch.prim.ListConstruct %int4_3516, %306, %int14336_3517 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3045 = torch.aten.view %3043, %3044 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3045, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %3046 = torch.aten.mul.Tensor %3038, %3045 : 
!torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3046, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_3518 = torch.constant.int -2 - %int-1_3519 = torch.constant.int -1 - %3047 = torch.aten.transpose.int %126, %int-2_3518, %int-1_3519 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_3520 = torch.constant.int 1 - %3048 = torch.aten.size.int %3037, %int1_3520 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_3521 = torch.constant.int 4 - %3049 = torch.aten.mul.int %int4_3521, %3048 : !torch.int, !torch.int -> !torch.int - %int14336_3522 = torch.constant.int 14336 - %3050 = torch.prim.ListConstruct %3049, %int14336_3522 : (!torch.int, !torch.int) -> !torch.list - %3051 = torch.aten.view %3046, %3050 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %3051, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %3052 = torch.aten.mm %3051, %3047 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3052, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_3523 = torch.constant.int 4 - %int4096_3524 = torch.constant.int 4096 - %3053 = torch.prim.ListConstruct %int4_3523, %3048, %int4096_3524 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3054 = torch.aten.view %3052, %3053 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3054, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_3525 = torch.constant.int 1 - %3055 = torch.aten.add.Tensor %3020, %3054, %int1_3525 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3055, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_3526 = torch.constant.int 6 - %3056 = torch.prims.convert_element_type %3055, %int6_3526 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3056, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_3527 = torch.constant.int 2 - %3057 = torch.aten.pow.Tensor_Scalar %3056, %int2_3527 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3057, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_3528 = torch.constant.int -1 - %3058 = torch.prim.ListConstruct %int-1_3528 : (!torch.int) -> !torch.list - %true_3529 = torch.constant.bool true - %none_3530 = torch.constant.none - %3059 = torch.aten.mean.dim %3057, %3058, %true_3529, %none_3530 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3059, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_3531 = torch.constant.float 9.9999997473787516E-6 + %3015 = torch.aten.pow.Tensor_Scalar %3014, %int2_3477 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3015, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_3478 = torch.constant.int -1 + %3016 = torch.prim.ListConstruct 
%int-1_3478 : (!torch.int) -> !torch.list + %true_3479 = torch.constant.bool true + %none_3480 = torch.constant.none + %3017 = torch.aten.mean.dim %3015, %3016, %true_3479, %none_3480 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3017, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_3481 = torch.constant.float 9.9999997473787516E-6 + %int1_3482 = torch.constant.int 1 + %3018 = torch.aten.add.Scalar %3017, %float9.999990e-06_3481, %int1_3482 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3018, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %3019 = torch.aten.rsqrt %3018 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3019, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %3020 = torch.aten.mul.Tensor %3014, %3019 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3020, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_3483 = torch.constant.int 5 + %3021 = torch.prims.convert_element_type %3020, %int5_3483 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3021, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %3022 = torch.aten.mul.Tensor %92, %3021 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3022, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_3484 = torch.constant.int 5 + %3023 = torch.prims.convert_element_type %3022, %int5_3484 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3023, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_3485 = torch.constant.int -2 + %int-1_3486 = torch.constant.int -1 + %3024 = torch.aten.transpose.int %93, %int-2_3485, %int-1_3486 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_3487 = torch.constant.int 5 + %3025 = torch.prims.convert_element_type %3024, %int5_3487 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_3488 = torch.constant.int 4096 + %3026 = torch.prim.ListConstruct %342, %int4096_3488 : (!torch.int, !torch.int) -> !torch.list + %3027 = torch.aten.view %3023, %3026 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3027, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3028 = torch.aten.mm %3027, %3025 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3028, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_3489 = torch.constant.int 4 + %int4096_3490 = torch.constant.int 4096 + %3029 = torch.prim.ListConstruct %int4_3489, %298, %int4096_3490 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3030 = torch.aten.view %3028, %3029 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3030, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + 
%int-2_3491 = torch.constant.int -2 + %int-1_3492 = torch.constant.int -1 + %3031 = torch.aten.transpose.int %94, %int-2_3491, %int-1_3492 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_3493 = torch.constant.int 5 + %3032 = torch.prims.convert_element_type %3031, %int5_3493 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_3494 = torch.constant.int 4096 + %3033 = torch.prim.ListConstruct %342, %int4096_3494 : (!torch.int, !torch.int) -> !torch.list + %3034 = torch.aten.view %3023, %3033 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3034, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3035 = torch.aten.mm %3034, %3032 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %3035, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_3495 = torch.constant.int 4 + %int1024_3496 = torch.constant.int 1024 + %3036 = torch.prim.ListConstruct %int4_3495, %298, %int1024_3496 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3037 = torch.aten.view %3035, %3036 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %3037, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_3497 = torch.constant.int -2 + %int-1_3498 = torch.constant.int -1 + %3038 = torch.aten.transpose.int %95, %int-2_3497, %int-1_3498 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_3499 = torch.constant.int 5 + %3039 = torch.prims.convert_element_type %3038, %int5_3499 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_3500 = torch.constant.int 4096 + %3040 = torch.prim.ListConstruct %342, %int4096_3500 : (!torch.int, !torch.int) -> !torch.list + %3041 = torch.aten.view %3023, %3040 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3041, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3042 = torch.aten.mm %3041, %3039 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %3042, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_3501 = torch.constant.int 4 + %int1024_3502 = torch.constant.int 1024 + %3043 = torch.prim.ListConstruct %int4_3501, %298, %int1024_3502 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3044 = torch.aten.view %3042, %3043 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %3044, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_3503 = torch.constant.int 4 + %int32_3504 = torch.constant.int 32 + %int128_3505 = torch.constant.int 128 + %3045 = torch.prim.ListConstruct %int4_3503, %298, %int32_3504, %int128_3505 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3046 = torch.aten.view %3030, %3045 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3046, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_3506 = torch.constant.int 4 + %int8_3507 = torch.constant.int 8 + %int128_3508 = torch.constant.int 128 + %3047 = 
torch.prim.ListConstruct %int4_3506, %298, %int8_3507, %int128_3508 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3048 = torch.aten.view %3037, %3047 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3048, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_3509 = torch.constant.int 4 + %int8_3510 = torch.constant.int 8 + %int128_3511 = torch.constant.int 128 + %3049 = torch.prim.ListConstruct %int4_3509, %298, %int8_3510, %int128_3511 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3050 = torch.aten.view %3044, %3049 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3050, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_3512 = torch.constant.int 131072 + %none_3513 = torch.constant.none + %none_3514 = torch.constant.none + %cpu_3515 = torch.constant.device "cpu" + %false_3516 = torch.constant.bool false + %3051 = torch.aten.arange %int131072_3512, %none_3513, %none_3514, %cpu_3515, %false_3516 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_3517 = torch.constant.int 0 + %int128_3518 = torch.constant.int 128 + %int2_3519 = torch.constant.int 2 + %int4_3520 = torch.constant.int 4 + %none_3521 = torch.constant.none + %cpu_3522 = torch.constant.device "cpu" + %false_3523 = torch.constant.bool false + %3052 = torch.aten.arange.start_step %int0_3517, %int128_3518, %int2_3519, %int4_3520, %none_3521, %cpu_3522, %false_3523 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_3524 = torch.constant.int 6 + %3053 = torch.prims.convert_element_type %3052, %int6_3524 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_3525 = torch.constant.int 128 + %3054 = torch.aten.div.Scalar %3053, %int128_3525 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_3526 = torch.constant.float 5.000000e+05 + %3055 = torch.aten.pow.Scalar %float5.000000e05_3526, %3054 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3056 = torch.aten.reciprocal %3055 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_3527 = torch.constant.float 1.000000e+00 + %3057 = torch.aten.mul.Scalar %3056, %float1.000000e00_3527 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %3058 = torch.aten.reciprocal %3057 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_3528 = torch.constant.float 6.2831853071795862 + %3059 = torch.aten.mul.Scalar %3058, %float6.283190e00_3528 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_3529 = torch.constant.float 8.192000e+03 + %3060 = torch.aten.gt.Scalar %3059, %float8.192000e03_3529 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_3530 = torch.constant.int 8 + %3061 = torch.aten.div.Scalar %3057, %int8_3530 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %3062 = torch.aten.where.self %3060, %3061, %3057 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3063 = torch.aten.reciprocal %3059 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_3531 = torch.constant.int 8192 + %3064 = torch.aten.mul.Scalar %3063, %int8192_3531 : 
!torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_3532 = torch.constant.int 1 - %3060 = torch.aten.add.Scalar %3059, %float9.999990e-06_3531, %int1_3532 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3060, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3061 = torch.aten.rsqrt %3060 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3061, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3062 = torch.aten.mul.Tensor %3056, %3061 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3062, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3533 = torch.constant.int 5 - %3063 = torch.prims.convert_element_type %3062, %int5_3533 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3063, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %3064 = torch.aten.mul.Tensor %127, %3063 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3064, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3534 = torch.constant.int 5 - %3065 = torch.prims.convert_element_type %3064, %int5_3534 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3065, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3535 = torch.constant.int -2 - %int-1_3536 = torch.constant.int -1 - %3066 = torch.aten.transpose.int %128, %int-2_3535, %int-1_3536 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3537 = torch.constant.int 4 - %3067 = torch.aten.mul.int %int4_3537, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3538 = torch.constant.int 4096 - %3068 = torch.prim.ListConstruct %3067, %int4096_3538 : (!torch.int, !torch.int) -> !torch.list - %3069 = torch.aten.view %3065, %3068 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3069, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3070 = torch.aten.mm %3069, %3066 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3070, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_3539 = torch.constant.int 4 - %int4096_3540 = torch.constant.int 4096 - %3071 = torch.prim.ListConstruct %int4_3539, %306, %int4096_3540 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3072 = torch.aten.view %3070, %3071 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3072, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3541 = torch.constant.int -2 - %int-1_3542 = torch.constant.int -1 - %3073 = torch.aten.transpose.int %129, %int-2_3541, %int-1_3542 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3543 = torch.constant.int 4 - %3074 = torch.aten.mul.int %int4_3543, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3544 = torch.constant.int 4096 - %3075 = torch.prim.ListConstruct %3074, %int4096_3544 : (!torch.int, 
!torch.int) -> !torch.list - %3076 = torch.aten.view %3065, %3075 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3076, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3077 = torch.aten.mm %3076, %3073 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %3077, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_3545 = torch.constant.int 4 - %int1024_3546 = torch.constant.int 1024 - %3078 = torch.prim.ListConstruct %int4_3545, %306, %int1024_3546 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3079 = torch.aten.view %3077, %3078 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %3079, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_3547 = torch.constant.int -2 - %int-1_3548 = torch.constant.int -1 - %3080 = torch.aten.transpose.int %130, %int-2_3547, %int-1_3548 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3549 = torch.constant.int 4 - %3081 = torch.aten.mul.int %int4_3549, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3550 = torch.constant.int 4096 - %3082 = torch.prim.ListConstruct %3081, %int4096_3550 : (!torch.int, !torch.int) -> !torch.list - %3083 = torch.aten.view %3065, %3082 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3083, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3084 = torch.aten.mm %3083, %3080 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %3084, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_3551 = torch.constant.int 4 - %int1024_3552 = torch.constant.int 1024 - %3085 = torch.prim.ListConstruct %int4_3551, %306, %int1024_3552 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3086 = torch.aten.view %3084, %3085 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %3086, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_3553 = torch.constant.int 4 - %int32_3554 = torch.constant.int 32 - %int128_3555 = torch.constant.int 128 - %3087 = torch.prim.ListConstruct %int4_3553, %306, %int32_3554, %int128_3555 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3088 = torch.aten.view %3072, %3087 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3088, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_3556 = torch.constant.int 4 - %int8_3557 = torch.constant.int 8 - %int128_3558 = torch.constant.int 128 - %3089 = torch.prim.ListConstruct %int4_3556, %306, %int8_3557, %int128_3558 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3090 = torch.aten.view %3079, %3089 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3090, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_3559 = torch.constant.int 4 - %int8_3560 = torch.constant.int 8 - %int128_3561 = torch.constant.int 128 - %3091 = torch.prim.ListConstruct %int4_3559, %306, %int8_3560, %int128_3561 : (!torch.int, 
!torch.int, !torch.int, !torch.int) -> !torch.list - %3092 = torch.aten.view %3086, %3091 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3092, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_3562 = torch.constant.int 131072 - %none_3563 = torch.constant.none - %none_3564 = torch.constant.none - %cpu_3565 = torch.constant.device "cpu" - %false_3566 = torch.constant.bool false - %3093 = torch.aten.arange %int131072_3562, %none_3563, %none_3564, %cpu_3565, %false_3566 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_3567 = torch.constant.int 0 - %int128_3568 = torch.constant.int 128 - %none_3569 = torch.constant.none - %none_3570 = torch.constant.none - %cpu_3571 = torch.constant.device "cpu" - %false_3572 = torch.constant.bool false - %3094 = torch.aten.arange.start %int0_3567, %int128_3568, %none_3569, %none_3570, %cpu_3571, %false_3572 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_3573 = torch.constant.int 2 - %3095 = torch.aten.floor_divide.Scalar %3094, %int2_3573 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_3574 = torch.constant.int 6 - %3096 = torch.prims.convert_element_type %3095, %int6_3574 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_3575 = torch.constant.int 128 - %3097 = torch.aten.div.Scalar %3096, %int128_3575 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_3576 = torch.constant.float 2.000000e+00 - %3098 = torch.aten.mul.Scalar %3097, %float2.000000e00_3576 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_3577 = torch.constant.float 5.000000e+05 - %3099 = torch.aten.pow.Scalar %float5.000000e05_3577, %3098 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %3100 = torch.aten.reciprocal %3099 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_3578 = torch.constant.float 1.000000e+00 - %3101 = torch.aten.mul.Scalar %3100, %float1.000000e00_3578 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_3579 = torch.constant.int 1 - %3102 = torch.aten.unsqueeze %3093, %int1_3579 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_3580 = torch.constant.int 0 - %3103 = torch.aten.unsqueeze %3101, %int0_3580 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %3104 = torch.aten.mul.Tensor %3102, %3103 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %int1_3533 = torch.constant.int 1 + %3065 = torch.aten.sub.Scalar %3064, %int1_3532, %int1_3533 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_3534 = torch.constant.int 3 + %3066 = torch.aten.div.Scalar %3065, %int3_3534 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_3535 = torch.constant.int 1 + %int1_3536 = torch.constant.int 1 + %3067 = torch.aten.rsub.Scalar %3066, %int1_3535, %int1_3536 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %3068 = torch.aten.mul.Tensor %3067, %3062 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_3537 = torch.constant.int 8 + %3069 = torch.aten.div.Scalar %3068, %int8_3537 : !torch.vtensor<[64],f32>, 
!torch.int -> !torch.vtensor<[64],f32> + %3070 = torch.aten.mul.Tensor %3066, %3062 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_3538 = torch.constant.int 1 + %3071 = torch.aten.add.Tensor %3069, %3070, %int1_3538 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_3539 = torch.constant.float 2.048000e+03 + %3072 = torch.aten.lt.Scalar %3059, %float2.048000e03_3539 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %3073 = torch.aten.bitwise_not %3072 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_3540 = torch.constant.float 8.192000e+03 + %3074 = torch.aten.gt.Scalar %3059, %float8.192000e03_3540 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %3075 = torch.aten.bitwise_not %3074 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %3076 = torch.aten.mul.Tensor %3073, %3075 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %3077 = torch.aten.where.self %3076, %3071, %3062 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3078 = torch.prim.ListConstruct %3077, %3077 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_3541 = torch.constant.int -1 + %3079 = torch.aten.cat %3078, %int-1_3541 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_3542 = torch.constant.int 6 + %3080 = torch.prims.convert_element_type %3079, %int6_3542 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_3543 = torch.constant.int 1 + %3081 = torch.aten.unsqueeze %3051, %int1_3543 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_3544 = torch.constant.int 6 + %3082 = torch.prims.convert_element_type %3081, %int6_3544 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_3545 = torch.constant.int 0 + %3083 = torch.aten.unsqueeze %3080, %int0_3545 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_3546 = torch.constant.int 6 + %3084 = torch.prims.convert_element_type %3083, %int6_3546 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %3085 = torch.aten.mul.Tensor %3082, %3084 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %3086 = torch.aten.cos %3085 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_3547 = torch.constant.int 5 + %3087 = torch.prims.convert_element_type %3086, %int5_3547 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %3088 = torch.aten.sin %3085 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_3548 = torch.constant.int 5 + %3089 = torch.prims.convert_element_type %3088, %int5_3548 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_3549 = torch.constant.int 0 + %int0_3550 = torch.constant.int 0 + %int1_3551 = torch.constant.int 1 + %3090 = torch.aten.slice.Tensor %3087, %int0_3549, %int0_3550, %298, %int1_3551 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3090, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_3552 = torch.constant.int 1 + %int0_3553 = torch.constant.int 0 + %int9223372036854775807_3554 = torch.constant.int 9223372036854775807 + 
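// Reader note (editorial annotation, not emitted by the exporter): the '+' ops above, from
// %3051 through %3089, construct the rotary-embedding tables inline. The literals match the
// Llama 3.1 RoPE frequency-scaling scheme (an inference from the constants, not stated in the
// IR): base theta = 5.0e5, head dim 128, max context 131072, and per frequency
//   inv_freq_i = theta^(-2i/128),   wavelen_i = 2*pi / inv_freq_i
//   wavelen > 8192  (low-frequency band)  ->  inv_freq / 8
//   2048 <= wavelen <= 8192 (blend band)  ->  (1 - s) * inv_freq / 8 + s * inv_freq,
//                                             with s = (8192 / wavelen - 1) / (4 - 1)
//   wavelen < 2048  (high-frequency band) ->  inv_freq unchanged
// The scaled 64-entry vector is duplicated via cat (%3079) to span all 128 lanes, the outer
// product with positions 0..131071 yields the angle table (%3085), and cos/sin are computed
// in f32 and truncated to f16 (%3087, %3089) before being sliced to the live sequence length.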
%int1_3555 = torch.constant.int 1 + %3091 = torch.aten.slice.Tensor %3090, %int1_3552, %int0_3553, %int9223372036854775807_3554, %int1_3555 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3091, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_3556 = torch.constant.int 0 + %int0_3557 = torch.constant.int 0 + %int1_3558 = torch.constant.int 1 + %3092 = torch.aten.slice.Tensor %3089, %int0_3556, %int0_3557, %298, %int1_3558 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3092, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_3559 = torch.constant.int 1 + %int0_3560 = torch.constant.int 0 + %int9223372036854775807_3561 = torch.constant.int 9223372036854775807 + %int1_3562 = torch.constant.int 1 + %3093 = torch.aten.slice.Tensor %3092, %int1_3559, %int0_3560, %int9223372036854775807_3561, %int1_3562 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3093, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_3563 = torch.constant.int 0 + %3094 = torch.aten.unsqueeze %3091, %int0_3563 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3094, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_3564 = torch.constant.int 1 + %int0_3565 = torch.constant.int 0 + %int9223372036854775807_3566 = torch.constant.int 9223372036854775807 + %int1_3567 = torch.constant.int 1 + %3095 = torch.aten.slice.Tensor %3094, %int1_3564, %int0_3565, %int9223372036854775807_3566, %int1_3567 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3095, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_3568 = torch.constant.int 2 + %3096 = torch.aten.unsqueeze %3095, %int2_3568 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3096, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_3569 = torch.constant.int 3 + %int0_3570 = torch.constant.int 0 + %int9223372036854775807_3571 = torch.constant.int 9223372036854775807 + %int1_3572 = torch.constant.int 1 + %3097 = torch.aten.slice.Tensor %3096, %int3_3569, %int0_3570, %int9223372036854775807_3571, %int1_3572 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3097, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_3573 = torch.constant.int 4 + %int1_3574 = torch.constant.int 1 + %int1_3575 = torch.constant.int 1 + %int1_3576 = torch.constant.int 1 + %3098 = torch.prim.ListConstruct %int4_3573, %int1_3574, %int1_3575, %int1_3576 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3099 = torch.aten.repeat %3097, %3098 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3099, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_3577 = torch.constant.int 0 + %3100 = torch.aten.unsqueeze %3093, %int0_3577 : !torch.vtensor<[?,128],f16>, !torch.int -> 
!torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3100, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_3578 = torch.constant.int 1 + %int0_3579 = torch.constant.int 0 + %int9223372036854775807_3580 = torch.constant.int 9223372036854775807 %int1_3581 = torch.constant.int 1 - %3105 = torch.aten.size.int %3072, %int1_3581 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_3582 = torch.constant.int 0 - %3106 = torch.aten.add.int %int0_3582, %3105 : !torch.int, !torch.int -> !torch.int - %int0_3583 = torch.constant.int 0 + %3101 = torch.aten.slice.Tensor %3100, %int1_3578, %int0_3579, %int9223372036854775807_3580, %int1_3581 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3101, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_3582 = torch.constant.int 2 + %3102 = torch.aten.unsqueeze %3101, %int2_3582 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3102, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_3583 = torch.constant.int 3 %int0_3584 = torch.constant.int 0 - %int1_3585 = torch.constant.int 1 - %3107 = torch.aten.slice.Tensor %3104, %int0_3583, %int0_3584, %3106, %int1_3585 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3107, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %int9223372036854775807_3585 = torch.constant.int 9223372036854775807 %int1_3586 = torch.constant.int 1 - %int0_3587 = torch.constant.int 0 - %int9223372036854775807_3588 = torch.constant.int 9223372036854775807 + %3103 = torch.aten.slice.Tensor %3102, %int3_3583, %int0_3584, %int9223372036854775807_3585, %int1_3586 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3103, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_3587 = torch.constant.int 4 + %int1_3588 = torch.constant.int 1 %int1_3589 = torch.constant.int 1 - %3108 = torch.aten.slice.Tensor %3107, %int1_3586, %int0_3587, %int9223372036854775807_3588, %int1_3589 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3108, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> %int1_3590 = torch.constant.int 1 - %int0_3591 = torch.constant.int 0 - %int9223372036854775807_3592 = torch.constant.int 9223372036854775807 - %int1_3593 = torch.constant.int 1 - %3109 = torch.aten.slice.Tensor %3108, %int1_3590, %int0_3591, %int9223372036854775807_3592, %int1_3593 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3109, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_3594 = torch.constant.int 0 - %3110 = torch.aten.unsqueeze %3109, %int0_3594 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3110, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_3595 = torch.constant.int 1 - %int0_3596 = torch.constant.int 0 + %3104 = torch.prim.ListConstruct %int4_3587, %int1_3588, %int1_3589, %int1_3590 : (!torch.int, 
!torch.int, !torch.int, !torch.int) -> !torch.list + %3105 = torch.aten.repeat %3103, %3104 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3105, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %3106 = torch.aten.mul.Tensor %3046, %3099 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3106, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_3591 = torch.constant.int 3 + %int0_3592 = torch.constant.int 0 + %int64_3593 = torch.constant.int 64 + %int1_3594 = torch.constant.int 1 + %3107 = torch.aten.slice.Tensor %3046, %int3_3591, %int0_3592, %int64_3593, %int1_3594 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %3107, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_3595 = torch.constant.int 3 + %int64_3596 = torch.constant.int 64 %int9223372036854775807_3597 = torch.constant.int 9223372036854775807 %int1_3598 = torch.constant.int 1 - %3111 = torch.aten.slice.Tensor %3110, %int1_3595, %int0_3596, %int9223372036854775807_3597, %int1_3598 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3111, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_3599 = torch.constant.int 2 - %int0_3600 = torch.constant.int 0 - %int9223372036854775807_3601 = torch.constant.int 9223372036854775807 - %int1_3602 = torch.constant.int 1 - %3112 = torch.aten.slice.Tensor %3111, %int2_3599, %int0_3600, %int9223372036854775807_3601, %int1_3602 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3112, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_3603 = torch.constant.int 4 - %int1_3604 = torch.constant.int 1 - %int1_3605 = torch.constant.int 1 - %3113 = torch.prim.ListConstruct %int4_3603, %int1_3604, %int1_3605 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3114 = torch.aten.repeat %3112, %3113 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %3114, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_3606 = torch.constant.int 6 - %3115 = torch.prims.convert_element_type %3088, %int6_3606 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %3115, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %3116 = torch_c.to_builtin_tensor %3115 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %3117 = torch_c.to_builtin_tensor %3114 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %3118 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%3116, %3117) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %3119 = torch_c.from_builtin_tensor %3118 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %3119, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_3607 = torch.constant.int 5 - %3120 = torch.prims.convert_element_type %3119, %int5_3607 : !torch.vtensor<[4,?,32,128],f32>, !torch.int 
-> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3120, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_3608 = torch.constant.int 131072 - %none_3609 = torch.constant.none + %3108 = torch.aten.slice.Tensor %3046, %int3_3595, %int64_3596, %int9223372036854775807_3597, %int1_3598 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %3108, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %3109 = torch.aten.neg %3108 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %3109, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %3110 = torch.prim.ListConstruct %3109, %3107 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_3599 = torch.constant.int -1 + %3111 = torch.aten.cat %3110, %int-1_3599 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3111, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %3112 = torch.aten.mul.Tensor %3111, %3105 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3112, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_3600 = torch.constant.int 1 + %3113 = torch.aten.add.Tensor %3106, %3112, %int1_3600 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3113, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_3601 = torch.constant.int 131072 + %none_3602 = torch.constant.none + %none_3603 = torch.constant.none + %cpu_3604 = torch.constant.device "cpu" + %false_3605 = torch.constant.bool false + %3114 = torch.aten.arange %int131072_3601, %none_3602, %none_3603, %cpu_3604, %false_3605 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_3606 = torch.constant.int 0 + %int128_3607 = torch.constant.int 128 + %int2_3608 = torch.constant.int 2 + %int4_3609 = torch.constant.int 4 %none_3610 = torch.constant.none %cpu_3611 = torch.constant.device "cpu" %false_3612 = torch.constant.bool false - %3121 = torch.aten.arange %int131072_3608, %none_3609, %none_3610, %cpu_3611, %false_3612 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_3613 = torch.constant.int 0 + %3115 = torch.aten.arange.start_step %int0_3606, %int128_3607, %int2_3608, %int4_3609, %none_3610, %cpu_3611, %false_3612 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_3613 = torch.constant.int 6 + %3116 = torch.prims.convert_element_type %3115, %int6_3613 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> %int128_3614 = torch.constant.int 128 - %none_3615 = torch.constant.none - %none_3616 = torch.constant.none - %cpu_3617 = torch.constant.device "cpu" - %false_3618 = torch.constant.bool false - %3122 = torch.aten.arange.start %int0_3613, %int128_3614, %none_3615, %none_3616, %cpu_3617, %false_3618 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_3619 = torch.constant.int 2 - %3123 = 
torch.aten.floor_divide.Scalar %3122, %int2_3619 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_3620 = torch.constant.int 6 - %3124 = torch.prims.convert_element_type %3123, %int6_3620 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_3621 = torch.constant.int 128 - %3125 = torch.aten.div.Scalar %3124, %int128_3621 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_3622 = torch.constant.float 2.000000e+00 - %3126 = torch.aten.mul.Scalar %3125, %float2.000000e00_3622 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_3623 = torch.constant.float 5.000000e+05 - %3127 = torch.aten.pow.Scalar %float5.000000e05_3623, %3126 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %3128 = torch.aten.reciprocal %3127 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_3624 = torch.constant.float 1.000000e+00 - %3129 = torch.aten.mul.Scalar %3128, %float1.000000e00_3624 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %3117 = torch.aten.div.Scalar %3116, %int128_3614 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_3615 = torch.constant.float 5.000000e+05 + %3118 = torch.aten.pow.Scalar %float5.000000e05_3615, %3117 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3119 = torch.aten.reciprocal %3118 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_3616 = torch.constant.float 1.000000e+00 + %3120 = torch.aten.mul.Scalar %3119, %float1.000000e00_3616 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %3121 = torch.aten.reciprocal %3120 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_3617 = torch.constant.float 6.2831853071795862 + %3122 = torch.aten.mul.Scalar %3121, %float6.283190e00_3617 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_3618 = torch.constant.float 8.192000e+03 + %3123 = torch.aten.gt.Scalar %3122, %float8.192000e03_3618 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_3619 = torch.constant.int 8 + %3124 = torch.aten.div.Scalar %3120, %int8_3619 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %3125 = torch.aten.where.self %3123, %3124, %3120 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3126 = torch.aten.reciprocal %3122 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_3620 = torch.constant.int 8192 + %3127 = torch.aten.mul.Scalar %3126, %int8192_3620 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_3621 = torch.constant.int 1 + %int1_3622 = torch.constant.int 1 + %3128 = torch.aten.sub.Scalar %3127, %int1_3621, %int1_3622 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_3623 = torch.constant.int 3 + %3129 = torch.aten.div.Scalar %3128, %int3_3623 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_3624 = torch.constant.int 1 %int1_3625 = torch.constant.int 1 - %3130 = torch.aten.unsqueeze %3121, %int1_3625 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_3626 = torch.constant.int 0 - %3131 = torch.aten.unsqueeze %3129, %int0_3626 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %3132 = torch.aten.mul.Tensor %3130, %3131 : 
!torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %3130 = torch.aten.rsub.Scalar %3129, %int1_3624, %int1_3625 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %3131 = torch.aten.mul.Tensor %3130, %3125 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_3626 = torch.constant.int 8 + %3132 = torch.aten.div.Scalar %3131, %int8_3626 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %3133 = torch.aten.mul.Tensor %3129, %3125 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> %int1_3627 = torch.constant.int 1 - %3133 = torch.aten.size.int %3079, %int1_3627 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_3628 = torch.constant.int 0 - %3134 = torch.aten.add.int %int0_3628, %3133 : !torch.int, !torch.int -> !torch.int - %int0_3629 = torch.constant.int 0 - %int0_3630 = torch.constant.int 0 - %int1_3631 = torch.constant.int 1 - %3135 = torch.aten.slice.Tensor %3132, %int0_3629, %int0_3630, %3134, %int1_3631 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3135, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %3134 = torch.aten.add.Tensor %3132, %3133, %int1_3627 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_3628 = torch.constant.float 2.048000e+03 + %3135 = torch.aten.lt.Scalar %3122, %float2.048000e03_3628 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %3136 = torch.aten.bitwise_not %3135 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_3629 = torch.constant.float 8.192000e+03 + %3137 = torch.aten.gt.Scalar %3122, %float8.192000e03_3629 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %3138 = torch.aten.bitwise_not %3137 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %3139 = torch.aten.mul.Tensor %3136, %3138 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %3140 = torch.aten.where.self %3139, %3134, %3125 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3141 = torch.prim.ListConstruct %3140, %3140 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_3630 = torch.constant.int -1 + %3142 = torch.aten.cat %3141, %int-1_3630 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_3631 = torch.constant.int 6 + %3143 = torch.prims.convert_element_type %3142, %int6_3631 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> %int1_3632 = torch.constant.int 1 - %int0_3633 = torch.constant.int 0 - %int9223372036854775807_3634 = torch.constant.int 9223372036854775807 - %int1_3635 = torch.constant.int 1 - %3136 = torch.aten.slice.Tensor %3135, %int1_3632, %int0_3633, %int9223372036854775807_3634, %int1_3635 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3136, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3636 = torch.constant.int 1 - %int0_3637 = torch.constant.int 0 - %int9223372036854775807_3638 = torch.constant.int 9223372036854775807 - %int1_3639 = torch.constant.int 1 - %3137 = torch.aten.slice.Tensor %3136, %int1_3636, %int0_3637, %int9223372036854775807_3638, %int1_3639 : 
!torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3137, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_3640 = torch.constant.int 0 - %3138 = torch.aten.unsqueeze %3137, %int0_3640 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3138, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> + %3144 = torch.aten.unsqueeze %3114, %int1_3632 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_3633 = torch.constant.int 6 + %3145 = torch.prims.convert_element_type %3144, %int6_3633 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_3634 = torch.constant.int 0 + %3146 = torch.aten.unsqueeze %3143, %int0_3634 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_3635 = torch.constant.int 6 + %3147 = torch.prims.convert_element_type %3146, %int6_3635 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %3148 = torch.aten.mul.Tensor %3145, %3147 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %3149 = torch.aten.cos %3148 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_3636 = torch.constant.int 5 + %3150 = torch.prims.convert_element_type %3149, %int5_3636 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %3151 = torch.aten.sin %3148 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_3637 = torch.constant.int 5 + %3152 = torch.prims.convert_element_type %3151, %int5_3637 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_3638 = torch.constant.int 0 + %int0_3639 = torch.constant.int 0 + %int1_3640 = torch.constant.int 1 + %3153 = torch.aten.slice.Tensor %3150, %int0_3638, %int0_3639, %298, %int1_3640 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3153, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_3641 = torch.constant.int 1 %int0_3642 = torch.constant.int 0 %int9223372036854775807_3643 = torch.constant.int 9223372036854775807 %int1_3644 = torch.constant.int 1 - %3139 = torch.aten.slice.Tensor %3138, %int1_3641, %int0_3642, %int9223372036854775807_3643, %int1_3644 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3139, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_3645 = torch.constant.int 2 + %3154 = torch.aten.slice.Tensor %3153, %int1_3641, %int0_3642, %int9223372036854775807_3643, %int1_3644 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3154, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_3645 = torch.constant.int 0 %int0_3646 = torch.constant.int 0 - %int9223372036854775807_3647 = torch.constant.int 9223372036854775807 + %int1_3647 = torch.constant.int 1 + %3155 = torch.aten.slice.Tensor %3152, %int0_3645, %int0_3646, %298, %int1_3647 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3155, [%294], affine_map<()[s0] -> (s0 * 
32, 128)> : !torch.vtensor<[?,128],f16> %int1_3648 = torch.constant.int 1 - %3140 = torch.aten.slice.Tensor %3139, %int2_3645, %int0_3646, %int9223372036854775807_3647, %int1_3648 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3140, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_3649 = torch.constant.int 4 - %int1_3650 = torch.constant.int 1 + %int0_3649 = torch.constant.int 0 + %int9223372036854775807_3650 = torch.constant.int 9223372036854775807 %int1_3651 = torch.constant.int 1 - %3141 = torch.prim.ListConstruct %int4_3649, %int1_3650, %int1_3651 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3142 = torch.aten.repeat %3140, %3141 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %3142, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_3652 = torch.constant.int 6 - %3143 = torch.prims.convert_element_type %3090, %int6_3652 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %3143, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %3144 = torch_c.to_builtin_tensor %3143 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %3145 = torch_c.to_builtin_tensor %3142 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %3146 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%3144, %3145) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %3147 = torch_c.from_builtin_tensor %3146 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %3147, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_3653 = torch.constant.int 5 - %3148 = torch.prims.convert_element_type %3147, %int5_3653 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3148, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_3654 = torch.constant.int 64 - %3149 = torch.aten.mul.Scalar %arg2, %int64_3654 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3149, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int28 = torch.constant.int 28 - %int1_3655 = torch.constant.int 1 - %3150 = torch.aten.add.Scalar %3149, %int28, %int1_3655 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3150, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_3656 = torch.constant.int 4 - %int32_3657 = torch.constant.int 32 - %int8_3658 = torch.constant.int 8 - %int128_3659 = torch.constant.int 128 - %3151 = torch.prim.ListConstruct %int4_3656, %398, %int32_3657, %int8_3658, %int128_3659 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3152 = torch.aten.view %3148, %3151 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3152, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_3660 = torch.constant.int 4 - %3153 = torch.aten.mul.int %int4_3660, %398 : !torch.int, !torch.int -> !torch.int - %int32_3661 = torch.constant.int 32 - %int8_3662 = torch.constant.int 8 - %int128_3663 = torch.constant.int 128 - %3154 = torch.prim.ListConstruct 
%3153, %int32_3661, %int8_3662, %int128_3663 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3155 = torch.aten.view %3152, %3154 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3155, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_3664 = torch.constant.int 4 - %3156 = torch.aten.mul.int %int4_3664, %398 : !torch.int, !torch.int -> !torch.int - %3157 = torch.prim.ListConstruct %3156 : (!torch.int) -> !torch.list - %3158 = torch.aten.view %3150, %3157 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3158, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_3665 = torch.constant.int 32 - %int2_3666 = torch.constant.int 2 - %int32_3667 = torch.constant.int 32 - %int8_3668 = torch.constant.int 8 - %int128_3669 = torch.constant.int 128 - %3159 = torch.prim.ListConstruct %389, %int32_3665, %int2_3666, %int32_3667, %int8_3668, %int128_3669 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3160 = torch.aten.view %2992, %3159 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3160, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_3670 = torch.constant.int 32 - %3161 = torch.aten.mul.int %389, %int32_3670 : !torch.int, !torch.int -> !torch.int + %3156 = torch.aten.slice.Tensor %3155, %int1_3648, %int0_3649, %int9223372036854775807_3650, %int1_3651 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3156, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_3652 = torch.constant.int 0 + %3157 = torch.aten.unsqueeze %3154, %int0_3652 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3157, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_3653 = torch.constant.int 1 + %int0_3654 = torch.constant.int 0 + %int9223372036854775807_3655 = torch.constant.int 9223372036854775807 + %int1_3656 = torch.constant.int 1 + %3158 = torch.aten.slice.Tensor %3157, %int1_3653, %int0_3654, %int9223372036854775807_3655, %int1_3656 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3158, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_3657 = torch.constant.int 2 + %3159 = torch.aten.unsqueeze %3158, %int2_3657 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3159, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_3658 = torch.constant.int 3 + %int0_3659 = torch.constant.int 0 + %int9223372036854775807_3660 = torch.constant.int 9223372036854775807 + %int1_3661 = torch.constant.int 1 + %3160 = torch.aten.slice.Tensor %3159, %int3_3658, %int0_3659, %int9223372036854775807_3660, %int1_3661 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3160, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_3662 = torch.constant.int 4 + %int1_3663 = torch.constant.int 1 + %int1_3664 = torch.constant.int 1 + 
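// Reader note (editorial annotation): the '-' lines in this hunk are the previous lowering,
// which widened Q/K to f32 and called the fused kernels
// @sharktank_rotary_embedding_4_D_32_128_f32 / @sharktank_rotary_embedding_4_D_8_128_f32.
// The '+' lines replace that with a decomposed, all-f16 application directly on the
// [4,?,32,128] query and [4,?,8,128] key tensors:
//   rotated(x) = x * cos + rotate_half(x) * sin,
//   rotate_half(x) = cat(-x[..., 64:128], x[..., 0:64], dim=-1)
// (%3106/%3111/%3113 for Q; %3169/%3174/%3176 for K). The cos/sin tables are expanded through
// unsqueeze + repeat to [4, seqlen, 1, 128] once, so they broadcast across the head dimension
// in the multiplies.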
%int1_3665 = torch.constant.int 1 + %3161 = torch.prim.ListConstruct %int4_3662, %int1_3663, %int1_3664, %int1_3665 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3162 = torch.aten.repeat %3160, %3161 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3162, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_3666 = torch.constant.int 0 + %3163 = torch.aten.unsqueeze %3156, %int0_3666 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3163, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_3667 = torch.constant.int 1 + %int0_3668 = torch.constant.int 0 + %int9223372036854775807_3669 = torch.constant.int 9223372036854775807 + %int1_3670 = torch.constant.int 1 + %3164 = torch.aten.slice.Tensor %3163, %int1_3667, %int0_3668, %int9223372036854775807_3669, %int1_3670 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3164, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int2_3671 = torch.constant.int 2 - %3162 = torch.aten.mul.int %3161, %int2_3671 : !torch.int, !torch.int -> !torch.int - %int32_3672 = torch.constant.int 32 - %int8_3673 = torch.constant.int 8 - %int128_3674 = torch.constant.int 128 - %3163 = torch.prim.ListConstruct %3162, %int32_3672, %int8_3673, %int128_3674 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3164 = torch.aten.view %3160, %3163 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3164, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %3165 = torch.prim.ListConstruct %3158 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_3675 = torch.constant.bool false - %3166 = torch.aten.index_put %3164, %3165, %3155, %false_3675 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3166, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_3676 = torch.constant.int 32 - %int2_3677 = torch.constant.int 2 - %int32_3678 = torch.constant.int 32 - %int8_3679 = torch.constant.int 8 - %int128_3680 = torch.constant.int 128 - %3167 = torch.prim.ListConstruct %389, %int32_3676, %int2_3677, %int32_3678, %int8_3679, %int128_3680 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3168 = torch.aten.view %3166, %3167 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3168, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3681 = torch.constant.int 2097152 - %3169 = torch.prim.ListConstruct %389, %int2097152_3681 : (!torch.int, !torch.int) -> !torch.list - %3170 = torch.aten.view %3168, %3169 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3170, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_3682 = torch.constant.int 32 - %int2_3683 = torch.constant.int 2 - %int32_3684 = torch.constant.int 32 - %int8_3685 = torch.constant.int 8 - %int128_3686 = torch.constant.int 128 - %3171 = torch.prim.ListConstruct %389, 
%int32_3682, %int2_3683, %int32_3684, %int8_3685, %int128_3686 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3172 = torch.aten.view %3170, %3171 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3172, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_3687 = torch.constant.int 32 - %int8_3688 = torch.constant.int 8 - %int128_3689 = torch.constant.int 128 - %3173 = torch.prim.ListConstruct %3162, %int32_3687, %int8_3688, %int128_3689 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3174 = torch.aten.view %3172, %3173 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3174, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_3690 = torch.constant.int 4 - %int32_3691 = torch.constant.int 32 - %int8_3692 = torch.constant.int 8 - %int128_3693 = torch.constant.int 128 - %3175 = torch.prim.ListConstruct %int4_3690, %398, %int32_3691, %int8_3692, %int128_3693 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3176 = torch.aten.view %3092, %3175 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3176, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_3694 = torch.constant.int 4 - %3177 = torch.aten.mul.int %int4_3694, %398 : !torch.int, !torch.int -> !torch.int - %int32_3695 = torch.constant.int 32 - %int8_3696 = torch.constant.int 8 - %int128_3697 = torch.constant.int 128 - %3178 = torch.prim.ListConstruct %3177, %int32_3695, %int8_3696, %int128_3697 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3179 = torch.aten.view %3176, %3178 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3179, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_3698 = torch.constant.int 1 - %int1_3699 = torch.constant.int 1 - %3180 = torch.aten.add.Scalar %3150, %int1_3698, %int1_3699 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3180, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_3700 = torch.constant.int 4 - %3181 = torch.aten.mul.int %int4_3700, %398 : !torch.int, !torch.int -> !torch.int - %3182 = torch.prim.ListConstruct %3181 : (!torch.int) -> !torch.list - %3183 = torch.aten.view %3180, %3182 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3183, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %3184 = torch.prim.ListConstruct %3183 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_3701 = torch.constant.bool false - %3185 = torch.aten.index_put %3174, %3184, %3179, %false_3701 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3185, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_3702 = torch.constant.int 32 + %3165 = torch.aten.unsqueeze %3164, %int2_3671 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3165, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : 
!torch.vtensor<[1,?,1,128],f16> + %int3_3672 = torch.constant.int 3 + %int0_3673 = torch.constant.int 0 + %int9223372036854775807_3674 = torch.constant.int 9223372036854775807 + %int1_3675 = torch.constant.int 1 + %3166 = torch.aten.slice.Tensor %3165, %int3_3672, %int0_3673, %int9223372036854775807_3674, %int1_3675 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3166, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_3676 = torch.constant.int 4 + %int1_3677 = torch.constant.int 1 + %int1_3678 = torch.constant.int 1 + %int1_3679 = torch.constant.int 1 + %3167 = torch.prim.ListConstruct %int4_3676, %int1_3677, %int1_3678, %int1_3679 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3168 = torch.aten.repeat %3166, %3167 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3168, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %3169 = torch.aten.mul.Tensor %3048, %3162 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3169, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_3680 = torch.constant.int 3 + %int0_3681 = torch.constant.int 0 + %int64_3682 = torch.constant.int 64 + %int1_3683 = torch.constant.int 1 + %3170 = torch.aten.slice.Tensor %3048, %int3_3680, %int0_3681, %int64_3682, %int1_3683 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %3170, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_3684 = torch.constant.int 3 + %int64_3685 = torch.constant.int 64 + %int9223372036854775807_3686 = torch.constant.int 9223372036854775807 + %int1_3687 = torch.constant.int 1 + %3171 = torch.aten.slice.Tensor %3048, %int3_3684, %int64_3685, %int9223372036854775807_3686, %int1_3687 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %3171, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %3172 = torch.aten.neg %3171 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %3172, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %3173 = torch.prim.ListConstruct %3172, %3170 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_3688 = torch.constant.int -1 + %3174 = torch.aten.cat %3173, %int-1_3688 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3174, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %3175 = torch.aten.mul.Tensor %3174, %3168 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3175, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_3689 = torch.constant.int 1 + %3176 = torch.aten.add.Tensor %3169, %3175, %int1_3689 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3176, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + 
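// Reader note (editorial annotation): the ops that follow scatter the rotated keys into the
// paged KV cache (the [?,2097152] f16 buffer). This revision also changes the page layout:
// the old '-' code viewed a page as [32, 2, 32, 8, 128] and addressed rows as
// page_id * 64 + 28 (then + 1 for the value plane), while the new '+' code views a page as
// [32, 2, 8, 32, 128], putting the 8 KV heads ahead of the 32-token block dimension, and
// builds the row index as (page_id * 32 + 10) * 2 + 0, with the trailing 0/1 selecting the
// key or value plane. The constants 28 and 10 appear to encode this transformer block's slot
// within a page; K is transposed to [?,8,32,128] (%3187) to match the head-major layout
// before the index_put (%3194).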
%int32_3690 = torch.constant.int 32 + %3177 = torch.aten.mul.Scalar %arg2, %int32_3690 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3177, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int10 = torch.constant.int 10 + %int1_3691 = torch.constant.int 1 + %3178 = torch.aten.add.Scalar %3177, %int10, %int1_3691 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3178, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_3692 = torch.constant.int 2 + %3179 = torch.aten.mul.Scalar %3178, %int2_3692 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3179, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_3693 = torch.constant.int 0 + %int1_3694 = torch.constant.int 1 + %3180 = torch.aten.add.Scalar %3179, %int0_3693, %int1_3694 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3180, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %3181 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %3182 = torch.aten.view %3180, %3181 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %3182, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_3695 = torch.constant.int 4 + %int32_3696 = torch.constant.int 32 + %int8_3697 = torch.constant.int 8 + %int128_3698 = torch.constant.int 128 + %3183 = torch.prim.ListConstruct %int4_3695, %296, %int32_3696, %int8_3697, %int128_3698 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3184 = torch.aten.view %3176, %3183 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3184, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_3699 = torch.constant.int 32 + %int8_3700 = torch.constant.int 8 + %int128_3701 = torch.constant.int 128 + %3185 = torch.prim.ListConstruct %504, %int32_3699, %int8_3700, %int128_3701 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3186 = torch.aten.view %3184, %3185 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %3186, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_3702 = torch.constant.int 1 %int2_3703 = torch.constant.int 2 - %int32_3704 = torch.constant.int 32 - %int8_3705 = torch.constant.int 8 - %int128_3706 = torch.constant.int 128 - %3186 = torch.prim.ListConstruct %389, %int32_3702, %int2_3703, %int32_3704, %int8_3705, %int128_3706 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3187 = torch.aten.view %3185, %3186 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3187, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3707 = torch.constant.int 2097152 - %3188 = torch.prim.ListConstruct %389, %int2097152_3707 : (!torch.int, !torch.int) -> !torch.list - %3189 = torch.aten.view %3187, %3188 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3189, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_3708 = 
torch.constant.int -2 - %3190 = torch.aten.unsqueeze %3148, %int-2_3708 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3190, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_3709 = torch.constant.int 4 + %3187 = torch.aten.transpose.int %3186, %int1_3702, %int2_3703 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3187, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_3704 = torch.constant.int 5 + %3188 = torch.prims.convert_element_type %3187, %int5_3704 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3188, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_3705 = torch.constant.int 32 + %int2_3706 = torch.constant.int 2 + %int8_3707 = torch.constant.int 8 + %int32_3708 = torch.constant.int 32 + %int128_3709 = torch.constant.int 128 + %3189 = torch.prim.ListConstruct %297, %int32_3705, %int2_3706, %int8_3707, %int32_3708, %int128_3709 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3190 = torch.aten.view %2952, %3189 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3190, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> %int8_3710 = torch.constant.int 8 - %int4_3711 = torch.constant.int 4 + %int32_3711 = torch.constant.int 32 %int128_3712 = torch.constant.int 128 - %3191 = torch.prim.ListConstruct %int4_3709, %3133, %int8_3710, %int4_3711, %int128_3712 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3191 = torch.prim.ListConstruct %497, %int8_3710, %int32_3711, %int128_3712 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3192 = torch.aten.view %3190, %3191 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3192, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %3193 = torch.prim.ListConstruct %3182 : (!torch.vtensor<[?],si64>) -> !torch.list> %false_3713 = torch.constant.bool false - %3192 = torch.aten.expand %3190, %3191, %false_3713 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3192, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_3714 = torch.constant.int 0 - %3193 = torch.aten.clone %3192, %int0_3714 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3193, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_3715 = torch.constant.int 4 - %int32_3716 = torch.constant.int 32 - %int128_3717 = torch.constant.int 128 - %3194 = torch.prim.ListConstruct %int4_3715, %3133, %int32_3716, %int128_3717 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3195 = torch.aten._unsafe_view %3193, %3194 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3195, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_3718 = torch.constant.int -2 - %3196 = torch.aten.unsqueeze %3092, %int-2_3718 : 
!torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3196, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_3719 = torch.constant.int 1 - %3197 = torch.aten.size.int %3086, %int1_3719 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_3720 = torch.constant.int 4 - %int8_3721 = torch.constant.int 8 - %int4_3722 = torch.constant.int 4 - %int128_3723 = torch.constant.int 128 - %3198 = torch.prim.ListConstruct %int4_3720, %3197, %int8_3721, %int4_3722, %int128_3723 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_3724 = torch.constant.bool false - %3199 = torch.aten.expand %3196, %3198, %false_3724 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3199, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_3725 = torch.constant.int 0 - %3200 = torch.aten.clone %3199, %int0_3725 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3200, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_3726 = torch.constant.int 4 - %int32_3727 = torch.constant.int 32 - %int128_3728 = torch.constant.int 128 - %3201 = torch.prim.ListConstruct %int4_3726, %3197, %int32_3727, %int128_3728 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3202 = torch.aten._unsafe_view %3200, %3201 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3202, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_3729 = torch.constant.int 1 - %int2_3730 = torch.constant.int 2 - %3203 = torch.aten.transpose.int %3120, %int1_3729, %int2_3730 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3203, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_3731 = torch.constant.int 1 - %int2_3732 = torch.constant.int 2 - %3204 = torch.aten.transpose.int %3195, %int1_3731, %int2_3732 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3204, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %3194 = torch.aten.index_put %3192, %3193, %3188, %false_3713 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3194, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_3714 = torch.constant.int 32 + %int2_3715 = torch.constant.int 2 + %int8_3716 = torch.constant.int 8 + %int32_3717 = torch.constant.int 32 + %int128_3718 = torch.constant.int 128 + %3195 = torch.prim.ListConstruct %297, %int32_3714, %int2_3715, %int8_3716, %int32_3717, %int128_3718 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3196 = torch.aten.view %3194, %3195 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3196, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_3719 = torch.constant.int 2097152 + %3197 = torch.prim.ListConstruct %297, 
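+ // Paged KV-cache write: keys were scattered via index_put at flat slots ((page_id * 32) + 10) * 2 + 0 into the cache viewed as [pages * 64, 8, 32, 128] (the +10 offset appears to pick this layer's slot; the trailing +0/+1 picks key vs. value). The cache is flattened back to [?, 2097152] and the matching value update at the +1 slots follows.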
%int2097152_3719 : (!torch.int, !torch.int) -> !torch.list + %3198 = torch.aten.view %3196, %3197 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %3198, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_3720 = torch.constant.int 32 + %int2_3721 = torch.constant.int 2 + %int8_3722 = torch.constant.int 8 + %int32_3723 = torch.constant.int 32 + %int128_3724 = torch.constant.int 128 + %3199 = torch.prim.ListConstruct %297, %int32_3720, %int2_3721, %int8_3722, %int32_3723, %int128_3724 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3200 = torch.aten.view %3198, %3199 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3200, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_3725 = torch.constant.int 8 + %int32_3726 = torch.constant.int 32 + %int128_3727 = torch.constant.int 128 + %3201 = torch.prim.ListConstruct %497, %int8_3725, %int32_3726, %int128_3727 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3202 = torch.aten.view %3200, %3201 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3202, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_3728 = torch.constant.int 32 + %3203 = torch.aten.mul.Scalar %arg2, %int32_3728 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3203, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int10_3729 = torch.constant.int 10 + %int1_3730 = torch.constant.int 1 + %3204 = torch.aten.add.Scalar %3203, %int10_3729, %int1_3730 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3204, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_3731 = torch.constant.int 2 + %3205 = torch.aten.mul.Scalar %3204, %int2_3731 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3205, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_3732 = torch.constant.int 1 %int1_3733 = torch.constant.int 1 - %int2_3734 = torch.constant.int 2 - %3205 = torch.aten.transpose.int %3202, %int1_3733, %int2_3734 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3205, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_3735 = torch.constant.float 0.000000e+00 - %true_3736 = torch.constant.bool true - %none_3737 = torch.constant.none - %none_3738 = torch.constant.none - %3206:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3203, %3204, %3205, %float0.000000e00_3735, %true_3736, %none_3737, %none_3738) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %3206#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_3739 = torch.constant.int 1 - %int2_3740 = torch.constant.int 2 - %3207 = torch.aten.transpose.int %3206#0, %int1_3739, %int2_3740 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, 
!torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3207, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_3741 = torch.constant.int 4 - %int4096_3742 = torch.constant.int 4096 - %3208 = torch.prim.ListConstruct %int4_3741, %3105, %int4096_3742 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3209 = torch.aten.view %3207, %3208 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3209, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3743 = torch.constant.int -2 - %int-1_3744 = torch.constant.int -1 - %3210 = torch.aten.transpose.int %131, %int-2_3743, %int-1_3744 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3745 = torch.constant.int 4 - %3211 = torch.aten.mul.int %int4_3745, %3105 : !torch.int, !torch.int -> !torch.int - %int4096_3746 = torch.constant.int 4096 - %3212 = torch.prim.ListConstruct %3211, %int4096_3746 : (!torch.int, !torch.int) -> !torch.list - %3213 = torch.aten.view %3209, %3212 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3213, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3214 = torch.aten.mm %3213, %3210 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3214, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_3747 = torch.constant.int 4 - %int4096_3748 = torch.constant.int 4096 - %3215 = torch.prim.ListConstruct %int4_3747, %3105, %int4096_3748 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3216 = torch.aten.view %3214, %3215 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3216, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_3749 = torch.constant.int 1 - %3217 = torch.aten.add.Tensor %3055, %3216, %int1_3749 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3217, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_3750 = torch.constant.int 6 - %3218 = torch.prims.convert_element_type %3217, %int6_3750 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3218, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_3751 = torch.constant.int 2 - %3219 = torch.aten.pow.Tensor_Scalar %3218, %int2_3751 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3219, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_3752 = torch.constant.int -1 - %3220 = torch.prim.ListConstruct %int-1_3752 : (!torch.int) -> !torch.list - %true_3753 = torch.constant.bool true - %none_3754 = torch.constant.none - %3221 = torch.aten.mean.dim %3219, %3220, %true_3753, %none_3754 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3221, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_3755 = torch.constant.float 9.9999997473787516E-6 - %int1_3756 = torch.constant.int 1 - %3222 = torch.aten.add.Scalar %3221, 
%float9.999990e-06_3755, %int1_3756 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3222, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3223 = torch.aten.rsqrt %3222 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3223, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3224 = torch.aten.mul.Tensor %3218, %3223 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3224, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3757 = torch.constant.int 5 - %3225 = torch.prims.convert_element_type %3224, %int5_3757 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3225, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %3226 = torch.aten.mul.Tensor %132, %3225 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3226, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3758 = torch.constant.int 5 - %3227 = torch.prims.convert_element_type %3226, %int5_3758 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3227, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3759 = torch.constant.int -2 - %int-1_3760 = torch.constant.int -1 - %3228 = torch.aten.transpose.int %133, %int-2_3759, %int-1_3760 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3761 = torch.constant.int 4 - %3229 = torch.aten.mul.int %int4_3761, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3762 = torch.constant.int 4096 - %3230 = torch.prim.ListConstruct %3229, %int4096_3762 : (!torch.int, !torch.int) -> !torch.list - %3231 = torch.aten.view %3227, %3230 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3231, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3232 = torch.aten.mm %3231, %3228 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %3232, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_3763 = torch.constant.int 4 - %int14336_3764 = torch.constant.int 14336 - %3233 = torch.prim.ListConstruct %int4_3763, %306, %int14336_3764 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3234 = torch.aten.view %3232, %3233 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3234, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %3235 = torch.aten.silu %3234 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3235, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_3765 = torch.constant.int -2 - %int-1_3766 = torch.constant.int -1 - %3236 = torch.aten.transpose.int %134, %int-2_3765, %int-1_3766 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3767 = torch.constant.int 4 - %3237 = torch.aten.mul.int %int4_3767, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3768 = 
torch.constant.int 4096 - %3238 = torch.prim.ListConstruct %3237, %int4096_3768 : (!torch.int, !torch.int) -> !torch.list - %3239 = torch.aten.view %3227, %3238 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3239, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3240 = torch.aten.mm %3239, %3236 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %3240, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_3769 = torch.constant.int 4 - %int14336_3770 = torch.constant.int 14336 - %3241 = torch.prim.ListConstruct %int4_3769, %306, %int14336_3770 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3242 = torch.aten.view %3240, %3241 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3242, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %3243 = torch.aten.mul.Tensor %3235, %3242 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3243, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_3771 = torch.constant.int -2 - %int-1_3772 = torch.constant.int -1 - %3244 = torch.aten.transpose.int %135, %int-2_3771, %int-1_3772 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %3206 = torch.aten.add.Scalar %3205, %int1_3732, %int1_3733 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3206, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %3207 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %3208 = torch.aten.view %3206, %3207 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %3208, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_3734 = torch.constant.int 4 + %int32_3735 = torch.constant.int 32 + %int8_3736 = torch.constant.int 8 + %int128_3737 = torch.constant.int 128 + %3209 = torch.prim.ListConstruct %int4_3734, %296, %int32_3735, %int8_3736, %int128_3737 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3210 = torch.aten.view %3050, %3209 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3210, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_3738 = torch.constant.int 32 + %int8_3739 = torch.constant.int 8 + %int128_3740 = torch.constant.int 128 + %3211 = torch.prim.ListConstruct %504, %int32_3738, %int8_3739, %int128_3740 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3212 = torch.aten.view %3210, %3211 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %3212, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_3741 = torch.constant.int 1 + %int2_3742 = torch.constant.int 2 + %3213 = torch.aten.transpose.int %3212, %int1_3741, %int2_3742 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3213, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_3743 = torch.constant.int 5 + %3214 = 
torch.prims.convert_element_type %3213, %int5_3743 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3214, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %3215 = torch.prim.ListConstruct %3208 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_3744 = torch.constant.bool false + %3216 = torch.aten.index_put %3202, %3215, %3214, %false_3744 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3216, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_3745 = torch.constant.int 32 + %int2_3746 = torch.constant.int 2 + %int8_3747 = torch.constant.int 8 + %int32_3748 = torch.constant.int 32 + %int128_3749 = torch.constant.int 128 + %3217 = torch.prim.ListConstruct %297, %int32_3745, %int2_3746, %int8_3747, %int32_3748, %int128_3749 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3218 = torch.aten.view %3216, %3217 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3218, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_3750 = torch.constant.int 2097152 + %3219 = torch.prim.ListConstruct %297, %int2097152_3750 : (!torch.int, !torch.int) -> !torch.list + %3220 = torch.aten.view %3218, %3219 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %3220, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_3751 = torch.constant.int -2 + %3221 = torch.aten.unsqueeze %3176, %int-2_3751 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %3221, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_3752 = torch.constant.int 4 + %int8_3753 = torch.constant.int 8 + %int4_3754 = torch.constant.int 4 + %int128_3755 = torch.constant.int 128 + %3222 = torch.prim.ListConstruct %int4_3752, %298, %int8_3753, %int4_3754, %int128_3755 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_3756 = torch.constant.bool false + %3223 = torch.aten.expand %3221, %3222, %false_3756 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3223, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_3757 = torch.constant.int 0 + %3224 = torch.aten.clone %3223, %int0_3757 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3224, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_3758 = torch.constant.int 4 + %int32_3759 = torch.constant.int 32 + %int128_3760 = torch.constant.int 128 + %3225 = torch.prim.ListConstruct %int4_3758, %298, %int32_3759, %int128_3760 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3226 = torch.aten._unsafe_view %3224, %3225 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3226, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_3761 = torch.constant.int -2 + %3227 = torch.aten.unsqueeze %3050, %int-2_3761 : 
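+ // Grouped-query attention: each of the 8 KV heads is repeated 4x (unsqueeze -> expand -> clone -> _unsafe_view) so K and V match the 32 query heads.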
!torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %3227, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_3762 = torch.constant.int 4 + %int8_3763 = torch.constant.int 8 + %int4_3764 = torch.constant.int 4 + %int128_3765 = torch.constant.int 128 + %3228 = torch.prim.ListConstruct %int4_3762, %298, %int8_3763, %int4_3764, %int128_3765 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_3766 = torch.constant.bool false + %3229 = torch.aten.expand %3227, %3228, %false_3766 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3229, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_3767 = torch.constant.int 0 + %3230 = torch.aten.clone %3229, %int0_3767 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3230, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_3768 = torch.constant.int 4 + %int32_3769 = torch.constant.int 32 + %int128_3770 = torch.constant.int 128 + %3231 = torch.prim.ListConstruct %int4_3768, %298, %int32_3769, %int128_3770 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3232 = torch.aten._unsafe_view %3230, %3231 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3232, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_3771 = torch.constant.int 1 + %int2_3772 = torch.constant.int 2 + %3233 = torch.aten.transpose.int %3113, %int1_3771, %int2_3772 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3233, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_3773 = torch.constant.int 1 - %3245 = torch.aten.size.int %3234, %int1_3773 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_3774 = torch.constant.int 4 - %3246 = torch.aten.mul.int %int4_3774, %3245 : !torch.int, !torch.int -> !torch.int - %int14336_3775 = torch.constant.int 14336 - %3247 = torch.prim.ListConstruct %3246, %int14336_3775 : (!torch.int, !torch.int) -> !torch.list - %3248 = torch.aten.view %3243, %3247 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %3248, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %3249 = torch.aten.mm %3248, %3244 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3249, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_3776 = torch.constant.int 4 - %int4096_3777 = torch.constant.int 4096 - %3250 = torch.prim.ListConstruct %int4_3776, %3245, %int4096_3777 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3251 = torch.aten.view %3249, %3250 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3251, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_3778 = torch.constant.int 1 - %3252 = torch.aten.add.Tensor %3217, %3251, %int1_3778 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape 
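+ // Flash attention (non-decomposed op): dropout 0.0 and is_causal = false, with the explicit [4,1,?,?] f16 mask %327 supplied where the previous revision passed is_causal = true and no mask.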
%3252, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_3779 = torch.constant.int 6 - %3253 = torch.prims.convert_element_type %3252, %int6_3779 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3253, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_3780 = torch.constant.int 2 - %3254 = torch.aten.pow.Tensor_Scalar %3253, %int2_3780 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3254, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_3781 = torch.constant.int -1 - %3255 = torch.prim.ListConstruct %int-1_3781 : (!torch.int) -> !torch.list - %true_3782 = torch.constant.bool true - %none_3783 = torch.constant.none - %3256 = torch.aten.mean.dim %3254, %3255, %true_3782, %none_3783 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3256, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_3784 = torch.constant.float 9.9999997473787516E-6 - %int1_3785 = torch.constant.int 1 - %3257 = torch.aten.add.Scalar %3256, %float9.999990e-06_3784, %int1_3785 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3257, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3258 = torch.aten.rsqrt %3257 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3258, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3259 = torch.aten.mul.Tensor %3253, %3258 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3259, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_3774 = torch.constant.int 2 + %3234 = torch.aten.transpose.int %3226, %int1_3773, %int2_3774 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3234, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_3775 = torch.constant.int 1 + %int2_3776 = torch.constant.int 2 + %3235 = torch.aten.transpose.int %3232, %int1_3775, %int2_3776 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3235, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_3777 = torch.constant.float 0.000000e+00 + %false_3778 = torch.constant.bool false + %none_3779 = torch.constant.none + %3236:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3233, %3234, %3235, %float0.000000e00_3777, %false_3778, %327, %none_3779) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %3236#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_3780 = torch.constant.int 1 + %int2_3781 = torch.constant.int 2 + %3237 = torch.aten.transpose.int %3236#0, %int1_3780, %int2_3781 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> 
!torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3237, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_3782 = torch.constant.int 4 + %int4096_3783 = torch.constant.int 4096 + %3238 = torch.prim.ListConstruct %int4_3782, %298, %int4096_3783 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3239 = torch.aten.view %3237, %3238 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3239, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_3784 = torch.constant.int -2 + %int-1_3785 = torch.constant.int -1 + %3240 = torch.aten.transpose.int %96, %int-2_3784, %int-1_3785 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> %int5_3786 = torch.constant.int 5 - %3260 = torch.prims.convert_element_type %3259, %int5_3786 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3260, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %3261 = torch.aten.mul.Tensor %136, %3260 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3261, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_3787 = torch.constant.int 5 - %3262 = torch.prims.convert_element_type %3261, %int5_3787 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3262, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3788 = torch.constant.int -2 - %int-1_3789 = torch.constant.int -1 - %3263 = torch.aten.transpose.int %137, %int-2_3788, %int-1_3789 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3790 = torch.constant.int 4 - %3264 = torch.aten.mul.int %int4_3790, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3791 = torch.constant.int 4096 - %3265 = torch.prim.ListConstruct %3264, %int4096_3791 : (!torch.int, !torch.int) -> !torch.list - %3266 = torch.aten.view %3262, %3265 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3266, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3267 = torch.aten.mm %3266, %3263 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3267, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_3792 = torch.constant.int 4 - %int4096_3793 = torch.constant.int 4096 - %3268 = torch.prim.ListConstruct %int4_3792, %306, %int4096_3793 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3269 = torch.aten.view %3267, %3268 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3269, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3794 = torch.constant.int -2 - %int-1_3795 = torch.constant.int -1 - %3270 = torch.aten.transpose.int %138, %int-2_3794, %int-1_3795 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3796 = torch.constant.int 4 - %3271 = torch.aten.mul.int %int4_3796, %306 : !torch.int, !torch.int -> !torch.int - %int4096_3797 = torch.constant.int 4096 - %3272 = torch.prim.ListConstruct %3271, %int4096_3797 : (!torch.int, !torch.int) -> 
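+ // Attention output projection (%96, 4096x4096) plus residual add, then RMSNorm: x * rsqrt(mean(x^2, dim=-1) + eps), eps = 9.9999997e-06 (f32 rounding of 1e-5), scaled by the ffn_norm weight (%97).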
!torch.list - %3273 = torch.aten.view %3262, %3272 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3273, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3274 = torch.aten.mm %3273, %3270 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %3274, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_3798 = torch.constant.int 4 - %int1024_3799 = torch.constant.int 1024 - %3275 = torch.prim.ListConstruct %int4_3798, %306, %int1024_3799 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3276 = torch.aten.view %3274, %3275 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %3276, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %3241 = torch.prims.convert_element_type %3240, %int5_3786 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_3787 = torch.constant.int 4096 + %3242 = torch.prim.ListConstruct %342, %int4096_3787 : (!torch.int, !torch.int) -> !torch.list + %3243 = torch.aten.view %3239, %3242 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3243, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3244 = torch.aten.mm %3243, %3241 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3244, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_3788 = torch.constant.int 4 + %int4096_3789 = torch.constant.int 4096 + %3245 = torch.prim.ListConstruct %int4_3788, %298, %int4096_3789 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3246 = torch.aten.view %3244, %3245 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3246, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_3790 = torch.constant.int 1 + %3247 = torch.aten.add.Tensor %3013, %3246, %int1_3790 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3247, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_3791 = torch.constant.int 6 + %3248 = torch.prims.convert_element_type %3247, %int6_3791 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3248, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_3792 = torch.constant.int 2 + %3249 = torch.aten.pow.Tensor_Scalar %3248, %int2_3792 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3249, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_3793 = torch.constant.int -1 + %3250 = torch.prim.ListConstruct %int-1_3793 : (!torch.int) -> !torch.list + %true_3794 = torch.constant.bool true + %none_3795 = torch.constant.none + %3251 = torch.aten.mean.dim %3249, %3250, %true_3794, %none_3795 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3251, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_3796 = torch.constant.float 
9.9999997473787516E-6 + %int1_3797 = torch.constant.int 1 + %3252 = torch.aten.add.Scalar %3251, %float9.999990e-06_3796, %int1_3797 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3252, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %3253 = torch.aten.rsqrt %3252 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3253, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %3254 = torch.aten.mul.Tensor %3248, %3253 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3254, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_3798 = torch.constant.int 5 + %3255 = torch.prims.convert_element_type %3254, %int5_3798 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3255, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %3256 = torch.aten.mul.Tensor %97, %3255 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3256, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_3799 = torch.constant.int 5 + %3257 = torch.prims.convert_element_type %3256, %int5_3799 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3257, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> %int-2_3800 = torch.constant.int -2 %int-1_3801 = torch.constant.int -1 - %3277 = torch.aten.transpose.int %139, %int-2_3800, %int-1_3801 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3802 = torch.constant.int 4 - %3278 = torch.aten.mul.int %int4_3802, %306 : !torch.int, !torch.int -> !torch.int + %3258 = torch.aten.transpose.int %98, %int-2_3800, %int-1_3801 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_3802 = torch.constant.int 5 + %3259 = torch.prims.convert_element_type %3258, %int5_3802 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> %int4096_3803 = torch.constant.int 4096 - %3279 = torch.prim.ListConstruct %3278, %int4096_3803 : (!torch.int, !torch.int) -> !torch.list - %3280 = torch.aten.view %3262, %3279 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3280, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3281 = torch.aten.mm %3280, %3277 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %3281, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %3260 = torch.prim.ListConstruct %342, %int4096_3803 : (!torch.int, !torch.int) -> !torch.list + %3261 = torch.aten.view %3257, %3260 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3261, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3262 = torch.aten.mm %3261, %3259 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %3262, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> %int4_3804 = torch.constant.int 4 - 
%int1024_3805 = torch.constant.int 1024 - %3282 = torch.prim.ListConstruct %int4_3804, %306, %int1024_3805 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3283 = torch.aten.view %3281, %3282 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %3283, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_3806 = torch.constant.int 4 - %int32_3807 = torch.constant.int 32 - %int128_3808 = torch.constant.int 128 - %3284 = torch.prim.ListConstruct %int4_3806, %306, %int32_3807, %int128_3808 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3285 = torch.aten.view %3269, %3284 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3285, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_3809 = torch.constant.int 4 - %int8_3810 = torch.constant.int 8 - %int128_3811 = torch.constant.int 128 - %3286 = torch.prim.ListConstruct %int4_3809, %306, %int8_3810, %int128_3811 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3287 = torch.aten.view %3276, %3286 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3287, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_3812 = torch.constant.int 4 - %int8_3813 = torch.constant.int 8 - %int128_3814 = torch.constant.int 128 - %3288 = torch.prim.ListConstruct %int4_3812, %306, %int8_3813, %int128_3814 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3289 = torch.aten.view %3283, %3288 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3289, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_3815 = torch.constant.int 131072 - %none_3816 = torch.constant.none - %none_3817 = torch.constant.none - %cpu_3818 = torch.constant.device "cpu" - %false_3819 = torch.constant.bool false - %3290 = torch.aten.arange %int131072_3815, %none_3816, %none_3817, %cpu_3818, %false_3819 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_3820 = torch.constant.int 0 - %int128_3821 = torch.constant.int 128 - %none_3822 = torch.constant.none + %int14336_3805 = torch.constant.int 14336 + %3263 = torch.prim.ListConstruct %int4_3804, %298, %int14336_3805 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3264 = torch.aten.view %3262, %3263 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %3264, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %3265 = torch.aten.silu %3264 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %3265, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_3806 = torch.constant.int -2 + %int-1_3807 = torch.constant.int -1 + %3266 = torch.aten.transpose.int %99, %int-2_3806, %int-1_3807 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_3808 = torch.constant.int 5 + %3267 = torch.prims.convert_element_type %3266, %int5_3808 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_3809 = torch.constant.int 4096 + %3268 = torch.prim.ListConstruct %342, %int4096_3809 : (!torch.int, 
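+ // SwiGLU FFN: down(silu(gate(x)) * up(x)) with the 14336-wide gate/up weights (%98, %99) and the down weight (%100); each transposed weight now passes through an f16 -> f16 prims.convert_element_type before the matmul.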
!torch.int) -> !torch.list + %3269 = torch.aten.view %3257, %3268 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3269, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3270 = torch.aten.mm %3269, %3267 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %3270, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_3810 = torch.constant.int 4 + %int14336_3811 = torch.constant.int 14336 + %3271 = torch.prim.ListConstruct %int4_3810, %298, %int14336_3811 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3272 = torch.aten.view %3270, %3271 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %3272, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %3273 = torch.aten.mul.Tensor %3265, %3272 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %3273, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_3812 = torch.constant.int -2 + %int-1_3813 = torch.constant.int -1 + %3274 = torch.aten.transpose.int %100, %int-2_3812, %int-1_3813 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_3814 = torch.constant.int 5 + %3275 = torch.prims.convert_element_type %3274, %int5_3814 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_3815 = torch.constant.int 14336 + %3276 = torch.prim.ListConstruct %342, %int14336_3815 : (!torch.int, !torch.int) -> !torch.list + %3277 = torch.aten.view %3273, %3276 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %3277, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %3278 = torch.aten.mm %3277, %3275 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3278, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_3816 = torch.constant.int 4 + %int4096_3817 = torch.constant.int 4096 + %3279 = torch.prim.ListConstruct %int4_3816, %298, %int4096_3817 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3280 = torch.aten.view %3278, %3279 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3280, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_3818 = torch.constant.int 1 + %3281 = torch.aten.add.Tensor %3247, %3280, %int1_3818 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3281, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_3819 = torch.constant.int 6 + %3282 = torch.prims.convert_element_type %3281, %int6_3819 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3282, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_3820 = torch.constant.int 2 + %3283 = torch.aten.pow.Tensor_Scalar %3282, %int2_3820 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3283, [%294], affine_map<()[s0] -> (4, s0 * 32, 
4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_3821 = torch.constant.int -1 + %3284 = torch.prim.ListConstruct %int-1_3821 : (!torch.int) -> !torch.list + %true_3822 = torch.constant.bool true %none_3823 = torch.constant.none - %cpu_3824 = torch.constant.device "cpu" - %false_3825 = torch.constant.bool false - %3291 = torch.aten.arange.start %int0_3820, %int128_3821, %none_3822, %none_3823, %cpu_3824, %false_3825 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_3826 = torch.constant.int 2 - %3292 = torch.aten.floor_divide.Scalar %3291, %int2_3826 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_3827 = torch.constant.int 6 - %3293 = torch.prims.convert_element_type %3292, %int6_3827 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_3828 = torch.constant.int 128 - %3294 = torch.aten.div.Scalar %3293, %int128_3828 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_3829 = torch.constant.float 2.000000e+00 - %3295 = torch.aten.mul.Scalar %3294, %float2.000000e00_3829 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_3830 = torch.constant.float 5.000000e+05 - %3296 = torch.aten.pow.Scalar %float5.000000e05_3830, %3295 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %3297 = torch.aten.reciprocal %3296 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_3831 = torch.constant.float 1.000000e+00 - %3298 = torch.aten.mul.Scalar %3297, %float1.000000e00_3831 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_3832 = torch.constant.int 1 - %3299 = torch.aten.unsqueeze %3290, %int1_3832 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_3833 = torch.constant.int 0 - %3300 = torch.aten.unsqueeze %3298, %int0_3833 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %3301 = torch.aten.mul.Tensor %3299, %3300 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_3834 = torch.constant.int 1 - %3302 = torch.aten.size.int %3269, %int1_3834 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_3835 = torch.constant.int 0 - %3303 = torch.aten.add.int %int0_3835, %3302 : !torch.int, !torch.int -> !torch.int - %int0_3836 = torch.constant.int 0 - %int0_3837 = torch.constant.int 0 - %int1_3838 = torch.constant.int 1 - %3304 = torch.aten.slice.Tensor %3301, %int0_3836, %int0_3837, %3303, %int1_3838 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3304, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3839 = torch.constant.int 1 - %int0_3840 = torch.constant.int 0 - %int9223372036854775807_3841 = torch.constant.int 9223372036854775807 - %int1_3842 = torch.constant.int 1 - %3305 = torch.aten.slice.Tensor %3304, %int1_3839, %int0_3840, %int9223372036854775807_3841, %int1_3842 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3305, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3843 = torch.constant.int 1 - %int0_3844 = torch.constant.int 0 - %int9223372036854775807_3845 = torch.constant.int 9223372036854775807 - %int1_3846 = torch.constant.int 1 - %3306 = 
torch.aten.slice.Tensor %3305, %int1_3843, %int0_3844, %int9223372036854775807_3845, %int1_3846 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3306, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_3847 = torch.constant.int 0 - %3307 = torch.aten.unsqueeze %3306, %int0_3847 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3307, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_3848 = torch.constant.int 1 - %int0_3849 = torch.constant.int 0 - %int9223372036854775807_3850 = torch.constant.int 9223372036854775807 - %int1_3851 = torch.constant.int 1 - %3308 = torch.aten.slice.Tensor %3307, %int1_3848, %int0_3849, %int9223372036854775807_3850, %int1_3851 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3308, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_3852 = torch.constant.int 2 - %int0_3853 = torch.constant.int 0 - %int9223372036854775807_3854 = torch.constant.int 9223372036854775807 - %int1_3855 = torch.constant.int 1 - %3309 = torch.aten.slice.Tensor %3308, %int2_3852, %int0_3853, %int9223372036854775807_3854, %int1_3855 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3309, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_3856 = torch.constant.int 4 - %int1_3857 = torch.constant.int 1 - %int1_3858 = torch.constant.int 1 - %3310 = torch.prim.ListConstruct %int4_3856, %int1_3857, %int1_3858 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3311 = torch.aten.repeat %3309, %3310 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %3311, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_3859 = torch.constant.int 6 - %3312 = torch.prims.convert_element_type %3285, %int6_3859 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %3312, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %3313 = torch_c.to_builtin_tensor %3312 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %3314 = torch_c.to_builtin_tensor %3311 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %3315 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%3313, %3314) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %3316 = torch_c.from_builtin_tensor %3315 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %3316, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_3860 = torch.constant.int 5 - %3317 = torch.prims.convert_element_type %3316, %int5_3860 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3317, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_3861 = torch.constant.int 131072 - %none_3862 = torch.constant.none - %none_3863 = torch.constant.none - %cpu_3864 = torch.constant.device "cpu" - %false_3865 = torch.constant.bool false - %3318 = torch.aten.arange %int131072_3861, %none_3862, %none_3863, %cpu_3864, 
%false_3865 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_3866 = torch.constant.int 0 - %int128_3867 = torch.constant.int 128 - %none_3868 = torch.constant.none - %none_3869 = torch.constant.none - %cpu_3870 = torch.constant.device "cpu" - %false_3871 = torch.constant.bool false - %3319 = torch.aten.arange.start %int0_3866, %int128_3867, %none_3868, %none_3869, %cpu_3870, %false_3871 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_3872 = torch.constant.int 2 - %3320 = torch.aten.floor_divide.Scalar %3319, %int2_3872 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_3873 = torch.constant.int 6 - %3321 = torch.prims.convert_element_type %3320, %int6_3873 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_3874 = torch.constant.int 128 - %3322 = torch.aten.div.Scalar %3321, %int128_3874 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_3875 = torch.constant.float 2.000000e+00 - %3323 = torch.aten.mul.Scalar %3322, %float2.000000e00_3875 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_3876 = torch.constant.float 5.000000e+05 - %3324 = torch.aten.pow.Scalar %float5.000000e05_3876, %3323 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %3325 = torch.aten.reciprocal %3324 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_3877 = torch.constant.float 1.000000e+00 - %3326 = torch.aten.mul.Scalar %3325, %float1.000000e00_3877 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %3285 = torch.aten.mean.dim %3283, %3284, %true_3822, %none_3823 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3285, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_3824 = torch.constant.float 9.9999997473787516E-6 + %int1_3825 = torch.constant.int 1 + %3286 = torch.aten.add.Scalar %3285, %float9.999990e-06_3824, %int1_3825 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3286, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %3287 = torch.aten.rsqrt %3286 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3287, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %3288 = torch.aten.mul.Tensor %3282, %3287 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3288, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_3826 = torch.constant.int 5 + %3289 = torch.prims.convert_element_type %3288, %int5_3826 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3289, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %3290 = torch.aten.mul.Tensor %101, %3289 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3290, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_3827 = torch.constant.int 5 + %3291 = torch.prims.convert_element_type %3290, %int5_3827 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> 
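+ // Next block begins: RMSNorm with the following attn_norm weight (%101), then the Q (%102, 4096x4096), K (%103, 1024x4096), and V (%104, 1024x4096) projections.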
!torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3291, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_3828 = torch.constant.int -2 + %int-1_3829 = torch.constant.int -1 + %3292 = torch.aten.transpose.int %102, %int-2_3828, %int-1_3829 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_3830 = torch.constant.int 5 + %3293 = torch.prims.convert_element_type %3292, %int5_3830 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_3831 = torch.constant.int 4096 + %3294 = torch.prim.ListConstruct %342, %int4096_3831 : (!torch.int, !torch.int) -> !torch.list + %3295 = torch.aten.view %3291, %3294 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3295, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3296 = torch.aten.mm %3295, %3293 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3296, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_3832 = torch.constant.int 4 + %int4096_3833 = torch.constant.int 4096 + %3297 = torch.prim.ListConstruct %int4_3832, %298, %int4096_3833 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3298 = torch.aten.view %3296, %3297 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3298, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_3834 = torch.constant.int -2 + %int-1_3835 = torch.constant.int -1 + %3299 = torch.aten.transpose.int %103, %int-2_3834, %int-1_3835 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_3836 = torch.constant.int 5 + %3300 = torch.prims.convert_element_type %3299, %int5_3836 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_3837 = torch.constant.int 4096 + %3301 = torch.prim.ListConstruct %342, %int4096_3837 : (!torch.int, !torch.int) -> !torch.list + %3302 = torch.aten.view %3291, %3301 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3302, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3303 = torch.aten.mm %3302, %3300 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %3303, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_3838 = torch.constant.int 4 + %int1024_3839 = torch.constant.int 1024 + %3304 = torch.prim.ListConstruct %int4_3838, %298, %int1024_3839 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3305 = torch.aten.view %3303, %3304 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %3305, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_3840 = torch.constant.int -2 + %int-1_3841 = torch.constant.int -1 + %3306 = torch.aten.transpose.int %104, %int-2_3840, %int-1_3841 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_3842 = torch.constant.int 5 + %3307 = torch.prims.convert_element_type %3306, %int5_3842 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_3843 = torch.constant.int 4096 + %3308 = 
torch.prim.ListConstruct %342, %int4096_3843 : (!torch.int, !torch.int) -> !torch.list + %3309 = torch.aten.view %3291, %3308 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3309, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3310 = torch.aten.mm %3309, %3307 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %3310, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_3844 = torch.constant.int 4 + %int1024_3845 = torch.constant.int 1024 + %3311 = torch.prim.ListConstruct %int4_3844, %298, %int1024_3845 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3312 = torch.aten.view %3310, %3311 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %3312, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_3846 = torch.constant.int 4 + %int32_3847 = torch.constant.int 32 + %int128_3848 = torch.constant.int 128 + %3313 = torch.prim.ListConstruct %int4_3846, %298, %int32_3847, %int128_3848 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3314 = torch.aten.view %3298, %3313 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3314, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_3849 = torch.constant.int 4 + %int8_3850 = torch.constant.int 8 + %int128_3851 = torch.constant.int 128 + %3315 = torch.prim.ListConstruct %int4_3849, %298, %int8_3850, %int128_3851 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3316 = torch.aten.view %3305, %3315 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3316, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_3852 = torch.constant.int 4 + %int8_3853 = torch.constant.int 8 + %int128_3854 = torch.constant.int 128 + %3317 = torch.prim.ListConstruct %int4_3852, %298, %int8_3853, %int128_3854 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3318 = torch.aten.view %3312, %3317 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3318, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_3855 = torch.constant.int 131072 + %none_3856 = torch.constant.none + %none_3857 = torch.constant.none + %cpu_3858 = torch.constant.device "cpu" + %false_3859 = torch.constant.bool false + %3319 = torch.aten.arange %int131072_3855, %none_3856, %none_3857, %cpu_3858, %false_3859 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_3860 = torch.constant.int 0 + %int128_3861 = torch.constant.int 128 + %int2_3862 = torch.constant.int 2 + %int4_3863 = torch.constant.int 4 + %none_3864 = torch.constant.none + %cpu_3865 = torch.constant.device "cpu" + %false_3866 = torch.constant.bool false + %3320 = torch.aten.arange.start_step %int0_3860, %int128_3861, %int2_3862, %int4_3863, %none_3864, %cpu_3865, %false_3866 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_3867 = torch.constant.int 6 + %3321 = torch.prims.convert_element_type %3320, %int6_3867 : !torch.vtensor<[64],si64>, !torch.int -> 
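// Head split for grouped-query attention (hedged): Q [4,?,4096] -> %3314 [4,?,32,128]
// (32 query heads), K and V [4,?,1024] -> %3316/%3318 [4,?,8,128] (8 KV heads, head_dim 128).
// The arange over 131072 positions plus arange.start_step(0, 128, 2) that follow begin the
// rotary table; note the removed `-` path paired dimensions via floor_divide by 2 (interleaved),
// while the new `+` path takes even indices directly (half-split rotate_half layout).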
!torch.vtensor<[64],f32> + %int128_3868 = torch.constant.int 128 + %3322 = torch.aten.div.Scalar %3321, %int128_3868 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_3869 = torch.constant.float 5.000000e+05 + %3323 = torch.aten.pow.Scalar %float5.000000e05_3869, %3322 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3324 = torch.aten.reciprocal %3323 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_3870 = torch.constant.float 1.000000e+00 + %3325 = torch.aten.mul.Scalar %3324, %float1.000000e00_3870 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %3326 = torch.aten.reciprocal %3325 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_3871 = torch.constant.float 6.2831853071795862 + %3327 = torch.aten.mul.Scalar %3326, %float6.283190e00_3871 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_3872 = torch.constant.float 8.192000e+03 + %3328 = torch.aten.gt.Scalar %3327, %float8.192000e03_3872 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_3873 = torch.constant.int 8 + %3329 = torch.aten.div.Scalar %3325, %int8_3873 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %3330 = torch.aten.where.self %3328, %3329, %3325 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3331 = torch.aten.reciprocal %3327 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_3874 = torch.constant.int 8192 + %3332 = torch.aten.mul.Scalar %3331, %int8192_3874 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_3875 = torch.constant.int 1 + %int1_3876 = torch.constant.int 1 + %3333 = torch.aten.sub.Scalar %3332, %int1_3875, %int1_3876 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_3877 = torch.constant.int 3 + %3334 = torch.aten.div.Scalar %3333, %int3_3877 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_3878 = torch.constant.int 1 - %3327 = torch.aten.unsqueeze %3318, %int1_3878 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_3879 = torch.constant.int 0 - %3328 = torch.aten.unsqueeze %3326, %int0_3879 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %3329 = torch.aten.mul.Tensor %3327, %3328 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_3880 = torch.constant.int 1 - %3330 = torch.aten.size.int %3276, %int1_3880 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_3881 = torch.constant.int 0 - %3331 = torch.aten.add.int %int0_3881, %3330 : !torch.int, !torch.int -> !torch.int - %int0_3882 = torch.constant.int 0 - %int0_3883 = torch.constant.int 0 - %int1_3884 = torch.constant.int 1 - %3332 = torch.aten.slice.Tensor %3329, %int0_3882, %int0_3883, %3331, %int1_3884 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3332, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3885 = torch.constant.int 1 - %int0_3886 = torch.constant.int 0 - %int9223372036854775807_3887 = torch.constant.int 9223372036854775807 - %int1_3888 = torch.constant.int 1 - %3333 = torch.aten.slice.Tensor %3332, %int1_3885, %int0_3886, %int9223372036854775807_3887, %int1_3888 : !torch.vtensor<[?,128],f32>, !torch.int, 
!torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3333, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_3889 = torch.constant.int 1 - %int0_3890 = torch.constant.int 0 - %int9223372036854775807_3891 = torch.constant.int 9223372036854775807 - %int1_3892 = torch.constant.int 1 - %3334 = torch.aten.slice.Tensor %3333, %int1_3889, %int0_3890, %int9223372036854775807_3891, %int1_3892 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3334, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %int1_3879 = torch.constant.int 1 + %3335 = torch.aten.rsub.Scalar %3334, %int1_3878, %int1_3879 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %3336 = torch.aten.mul.Tensor %3335, %3330 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_3880 = torch.constant.int 8 + %3337 = torch.aten.div.Scalar %3336, %int8_3880 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %3338 = torch.aten.mul.Tensor %3334, %3330 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_3881 = torch.constant.int 1 + %3339 = torch.aten.add.Tensor %3337, %3338, %int1_3881 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_3882 = torch.constant.float 2.048000e+03 + %3340 = torch.aten.lt.Scalar %3327, %float2.048000e03_3882 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %3341 = torch.aten.bitwise_not %3340 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_3883 = torch.constant.float 8.192000e+03 + %3342 = torch.aten.gt.Scalar %3327, %float8.192000e03_3883 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %3343 = torch.aten.bitwise_not %3342 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %3344 = torch.aten.mul.Tensor %3341, %3343 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %3345 = torch.aten.where.self %3344, %3339, %3330 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3346 = torch.prim.ListConstruct %3345, %3345 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_3884 = torch.constant.int -1 + %3347 = torch.aten.cat %3346, %int-1_3884 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_3885 = torch.constant.int 6 + %3348 = torch.prims.convert_element_type %3347, %int6_3885 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_3886 = torch.constant.int 1 + %3349 = torch.aten.unsqueeze %3319, %int1_3886 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_3887 = torch.constant.int 6 + %3350 = torch.prims.convert_element_type %3349, %int6_3887 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_3888 = torch.constant.int 0 + %3351 = torch.aten.unsqueeze %3348, %int0_3888 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_3889 = torch.constant.int 6 + %3352 = torch.prims.convert_element_type %3351, %int6_3889 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %3353 = torch.aten.mul.Tensor %3350, %3352 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %3354 = torch.aten.cos 
%3353 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_3890 = torch.constant.int 5 + %3355 = torch.prims.convert_element_type %3354, %int5_3890 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %3356 = torch.aten.sin %3353 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_3891 = torch.constant.int 5 + %3357 = torch.prims.convert_element_type %3356, %int5_3891 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_3892 = torch.constant.int 0 %int0_3893 = torch.constant.int 0 - %3335 = torch.aten.unsqueeze %3334, %int0_3893 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3335, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> %int1_3894 = torch.constant.int 1 - %int0_3895 = torch.constant.int 0 - %int9223372036854775807_3896 = torch.constant.int 9223372036854775807 - %int1_3897 = torch.constant.int 1 - %3336 = torch.aten.slice.Tensor %3335, %int1_3894, %int0_3895, %int9223372036854775807_3896, %int1_3897 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3336, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_3898 = torch.constant.int 2 + %3358 = torch.aten.slice.Tensor %3355, %int0_3892, %int0_3893, %298, %int1_3894 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3358, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_3895 = torch.constant.int 1 + %int0_3896 = torch.constant.int 0 + %int9223372036854775807_3897 = torch.constant.int 9223372036854775807 + %int1_3898 = torch.constant.int 1 + %3359 = torch.aten.slice.Tensor %3358, %int1_3895, %int0_3896, %int9223372036854775807_3897, %int1_3898 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3359, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int0_3899 = torch.constant.int 0 - %int9223372036854775807_3900 = torch.constant.int 9223372036854775807 + %int0_3900 = torch.constant.int 0 %int1_3901 = torch.constant.int 1 - %3337 = torch.aten.slice.Tensor %3336, %int2_3898, %int0_3899, %int9223372036854775807_3900, %int1_3901 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3337, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_3902 = torch.constant.int 4 - %int1_3903 = torch.constant.int 1 - %int1_3904 = torch.constant.int 1 - %3338 = torch.prim.ListConstruct %int4_3902, %int1_3903, %int1_3904 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3339 = torch.aten.repeat %3337, %3338 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %3339, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_3905 = torch.constant.int 6 - %3340 = torch.prims.convert_element_type %3287, %int6_3905 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %3340, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %3341 = torch_c.to_builtin_tensor %3340 : 
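// The `+` ops from %3320 through %3357 appear to build a llama3.1-style scaled RoPE table.
// Reconstruction (hedged):
//   inv_freq = 1 / 500000^(arange(0, 128, 2) / 128)                         (%3320..%3325)
//   wavelen  = 2*pi / inv_freq                                              (%3326..%3327)
//   where wavelen > 8192:          inv_freq = inv_freq / 8                  (%3328..%3330)
//   where 2048 <= wavelen <= 8192: smooth = (8192/wavelen - 1) / 3,
//                                  inv_freq = (1-smooth)*inv_freq/8 + smooth*inv_freq  (%3331..%3345)
//   angles = arange(131072)[:, None] * cat([inv_freq, inv_freq])[None, :]   (%3346..%3353)
// cos/sin are taken in f32, cast to f16 (%3355, %3357), and sliced to the live sequence length.
// The constants match scale factor 8, original context 8192, high/low frequency factors 4 and 1,
// base 500000, head_dim 128, 131072 max positions.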
!torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %3342 = torch_c.to_builtin_tensor %3339 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %3343 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%3341, %3342) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %3344 = torch_c.from_builtin_tensor %3343 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %3344, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_3906 = torch.constant.int 5 - %3345 = torch.prims.convert_element_type %3344, %int5_3906 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3345, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_3907 = torch.constant.int 64 - %3346 = torch.aten.mul.Scalar %arg2, %int64_3907 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3346, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int30 = torch.constant.int 30 - %int1_3908 = torch.constant.int 1 - %3347 = torch.aten.add.Scalar %3346, %int30, %int1_3908 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3347, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_3909 = torch.constant.int 4 - %int32_3910 = torch.constant.int 32 - %int8_3911 = torch.constant.int 8 - %int128_3912 = torch.constant.int 128 - %3348 = torch.prim.ListConstruct %int4_3909, %398, %int32_3910, %int8_3911, %int128_3912 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3349 = torch.aten.view %3345, %3348 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3349, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_3913 = torch.constant.int 4 - %3350 = torch.aten.mul.int %int4_3913, %398 : !torch.int, !torch.int -> !torch.int - %int32_3914 = torch.constant.int 32 - %int8_3915 = torch.constant.int 8 - %int128_3916 = torch.constant.int 128 - %3351 = torch.prim.ListConstruct %3350, %int32_3914, %int8_3915, %int128_3916 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3352 = torch.aten.view %3349, %3351 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3352, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_3917 = torch.constant.int 4 - %3353 = torch.aten.mul.int %int4_3917, %398 : !torch.int, !torch.int -> !torch.int - %3354 = torch.prim.ListConstruct %3353 : (!torch.int) -> !torch.list - %3355 = torch.aten.view %3347, %3354 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3355, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_3918 = torch.constant.int 32 - %int2_3919 = torch.constant.int 2 - %int32_3920 = torch.constant.int 32 - %int8_3921 = torch.constant.int 8 - %int128_3922 = torch.constant.int 128 - %3356 = torch.prim.ListConstruct %389, %int32_3918, %int2_3919, %int32_3920, %int8_3921, %int128_3922 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3357 = torch.aten.view %3189, %3356 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3357, [%293], 
affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_3923 = torch.constant.int 32 - %3358 = torch.aten.mul.int %389, %int32_3923 : !torch.int, !torch.int -> !torch.int - %int2_3924 = torch.constant.int 2 - %3359 = torch.aten.mul.int %3358, %int2_3924 : !torch.int, !torch.int -> !torch.int - %int32_3925 = torch.constant.int 32 - %int8_3926 = torch.constant.int 8 - %int128_3927 = torch.constant.int 128 - %3360 = torch.prim.ListConstruct %3359, %int32_3925, %int8_3926, %int128_3927 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3361 = torch.aten.view %3357, %3360 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3361, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %3362 = torch.prim.ListConstruct %3355 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_3928 = torch.constant.bool false - %3363 = torch.aten.index_put %3361, %3362, %3352, %false_3928 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3363, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_3929 = torch.constant.int 32 - %int2_3930 = torch.constant.int 2 - %int32_3931 = torch.constant.int 32 - %int8_3932 = torch.constant.int 8 - %int128_3933 = torch.constant.int 128 - %3364 = torch.prim.ListConstruct %389, %int32_3929, %int2_3930, %int32_3931, %int8_3932, %int128_3933 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3365 = torch.aten.view %3363, %3364 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3365, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3934 = torch.constant.int 2097152 - %3366 = torch.prim.ListConstruct %389, %int2097152_3934 : (!torch.int, !torch.int) -> !torch.list - %3367 = torch.aten.view %3365, %3366 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3367, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_3935 = torch.constant.int 32 - %int2_3936 = torch.constant.int 2 - %int32_3937 = torch.constant.int 32 - %int8_3938 = torch.constant.int 8 - %int128_3939 = torch.constant.int 128 - %3368 = torch.prim.ListConstruct %389, %int32_3935, %int2_3936, %int32_3937, %int8_3938, %int128_3939 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3369 = torch.aten.view %3367, %3368 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3369, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_3940 = torch.constant.int 32 - %int8_3941 = torch.constant.int 8 - %int128_3942 = torch.constant.int 128 - %3370 = torch.prim.ListConstruct %3359, %int32_3940, %int8_3941, %int128_3942 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3371 = torch.aten.view %3369, %3370 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3371, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_3943 = torch.constant.int 4 - %int32_3944 = torch.constant.int 32 - %int8_3945 = 
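// The `-` block here is the old paged KV-cache write: the flat cache [?,2097152] is viewed as
// [pages, 32, 2, 32, 8, 128] (pages x layers x K/V x block_seq x kv_heads x head_dim), flattened
// to [pages*64, 32, 8, 128], updated with index_put at indices arg2*64 + 30, and viewed back to
// [?,2097152]. A replacement write with different index arithmetic appears at the end of this chunk.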
torch.constant.int 8 - %int128_3946 = torch.constant.int 128 - %3372 = torch.prim.ListConstruct %int4_3943, %398, %int32_3944, %int8_3945, %int128_3946 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3373 = torch.aten.view %3289, %3372 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3373, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_3947 = torch.constant.int 4 - %3374 = torch.aten.mul.int %int4_3947, %398 : !torch.int, !torch.int -> !torch.int - %int32_3948 = torch.constant.int 32 - %int8_3949 = torch.constant.int 8 + %3360 = torch.aten.slice.Tensor %3357, %int0_3899, %int0_3900, %298, %int1_3901 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3360, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_3902 = torch.constant.int 1 + %int0_3903 = torch.constant.int 0 + %int9223372036854775807_3904 = torch.constant.int 9223372036854775807 + %int1_3905 = torch.constant.int 1 + %3361 = torch.aten.slice.Tensor %3360, %int1_3902, %int0_3903, %int9223372036854775807_3904, %int1_3905 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3361, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_3906 = torch.constant.int 0 + %3362 = torch.aten.unsqueeze %3359, %int0_3906 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3362, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_3907 = torch.constant.int 1 + %int0_3908 = torch.constant.int 0 + %int9223372036854775807_3909 = torch.constant.int 9223372036854775807 + %int1_3910 = torch.constant.int 1 + %3363 = torch.aten.slice.Tensor %3362, %int1_3907, %int0_3908, %int9223372036854775807_3909, %int1_3910 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3363, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_3911 = torch.constant.int 2 + %3364 = torch.aten.unsqueeze %3363, %int2_3911 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3364, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_3912 = torch.constant.int 3 + %int0_3913 = torch.constant.int 0 + %int9223372036854775807_3914 = torch.constant.int 9223372036854775807 + %int1_3915 = torch.constant.int 1 + %3365 = torch.aten.slice.Tensor %3364, %int3_3912, %int0_3913, %int9223372036854775807_3914, %int1_3915 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3365, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_3916 = torch.constant.int 4 + %int1_3917 = torch.constant.int 1 + %int1_3918 = torch.constant.int 1 + %int1_3919 = torch.constant.int 1 + %3366 = torch.prim.ListConstruct %int4_3916, %int1_3917, %int1_3918, %int1_3919 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3367 = torch.aten.repeat %3365, %3366 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3367, [%294], affine_map<()[s0] 
-> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_3920 = torch.constant.int 0 + %3368 = torch.aten.unsqueeze %3361, %int0_3920 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3368, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_3921 = torch.constant.int 1 + %int0_3922 = torch.constant.int 0 + %int9223372036854775807_3923 = torch.constant.int 9223372036854775807 + %int1_3924 = torch.constant.int 1 + %3369 = torch.aten.slice.Tensor %3368, %int1_3921, %int0_3922, %int9223372036854775807_3923, %int1_3924 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3369, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_3925 = torch.constant.int 2 + %3370 = torch.aten.unsqueeze %3369, %int2_3925 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3370, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_3926 = torch.constant.int 3 + %int0_3927 = torch.constant.int 0 + %int9223372036854775807_3928 = torch.constant.int 9223372036854775807 + %int1_3929 = torch.constant.int 1 + %3371 = torch.aten.slice.Tensor %3370, %int3_3926, %int0_3927, %int9223372036854775807_3928, %int1_3929 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3371, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_3930 = torch.constant.int 4 + %int1_3931 = torch.constant.int 1 + %int1_3932 = torch.constant.int 1 + %int1_3933 = torch.constant.int 1 + %3372 = torch.prim.ListConstruct %int4_3930, %int1_3931, %int1_3932, %int1_3933 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3373 = torch.aten.repeat %3371, %3372 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3373, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %3374 = torch.aten.mul.Tensor %3314, %3367 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3374, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_3934 = torch.constant.int 3 + %int0_3935 = torch.constant.int 0 + %int64_3936 = torch.constant.int 64 + %int1_3937 = torch.constant.int 1 + %3375 = torch.aten.slice.Tensor %3314, %int3_3934, %int0_3935, %int64_3936, %int1_3937 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %3375, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_3938 = torch.constant.int 3 + %int64_3939 = torch.constant.int 64 + %int9223372036854775807_3940 = torch.constant.int 9223372036854775807 + %int1_3941 = torch.constant.int 1 + %3376 = torch.aten.slice.Tensor %3314, %int3_3938, %int64_3939, %int9223372036854775807_3940, %int1_3941 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %3376, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %3377 = torch.aten.neg %3376 : !torch.vtensor<[4,?,32,64],f16> -> 
!torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %3377, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %3378 = torch.prim.ListConstruct %3377, %3375 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_3942 = torch.constant.int -1 + %3379 = torch.aten.cat %3378, %int-1_3942 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3379, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %3380 = torch.aten.mul.Tensor %3379, %3373 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3380, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_3943 = torch.constant.int 1 + %3381 = torch.aten.add.Tensor %3374, %3380, %int1_3943 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3381, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_3944 = torch.constant.int 131072 + %none_3945 = torch.constant.none + %none_3946 = torch.constant.none + %cpu_3947 = torch.constant.device "cpu" + %false_3948 = torch.constant.bool false + %3382 = torch.aten.arange %int131072_3944, %none_3945, %none_3946, %cpu_3947, %false_3948 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_3949 = torch.constant.int 0 %int128_3950 = torch.constant.int 128 - %3375 = torch.prim.ListConstruct %3374, %int32_3948, %int8_3949, %int128_3950 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3376 = torch.aten.view %3373, %3375 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3376, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_3951 = torch.constant.int 1 - %int1_3952 = torch.constant.int 1 - %3377 = torch.aten.add.Scalar %3347, %int1_3951, %int1_3952 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3377, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_3953 = torch.constant.int 4 - %3378 = torch.aten.mul.int %int4_3953, %398 : !torch.int, !torch.int -> !torch.int - %3379 = torch.prim.ListConstruct %3378 : (!torch.int) -> !torch.list - %3380 = torch.aten.view %3377, %3379 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3380, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %3381 = torch.prim.ListConstruct %3380 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_3954 = torch.constant.bool false - %3382 = torch.aten.index_put %3371, %3381, %3376, %false_3954 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3382, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_3955 = torch.constant.int 32 - %int2_3956 = torch.constant.int 2 - %int32_3957 = torch.constant.int 32 - %int8_3958 = torch.constant.int 8 - %int128_3959 = torch.constant.int 128 - %3383 = torch.prim.ListConstruct %389, %int32_3955, %int2_3956, %int32_3957, %int8_3958, %int128_3959 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> 
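// Rotary application to Q (hedged): with cos %3367 and sin %3373 broadcast as [4,?,1,128],
// %3381 = q*cos + rotate_half(q)*sin, where rotate_half(q) = cat([-q[..., 64:], q[..., :64]], dim=-1)
// (%3375..%3379). The fresh arange(131072) that follows rebuilds the identical cos/sin table for K
// instead of reusing %3355/%3357; a later CSE pass would presumably fold the duplicate.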
!torch.list - %3384 = torch.aten.view %3382, %3383 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3384, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3960 = torch.constant.int 2097152 - %3385 = torch.prim.ListConstruct %389, %int2097152_3960 : (!torch.int, !torch.int) -> !torch.list - %3386 = torch.aten.view %3384, %3385 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3386, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_3961 = torch.constant.int -2 - %3387 = torch.aten.unsqueeze %3345, %int-2_3961 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3387, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_3962 = torch.constant.int 4 - %int8_3963 = torch.constant.int 8 - %int4_3964 = torch.constant.int 4 - %int128_3965 = torch.constant.int 128 - %3388 = torch.prim.ListConstruct %int4_3962, %3330, %int8_3963, %int4_3964, %int128_3965 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_3966 = torch.constant.bool false - %3389 = torch.aten.expand %3387, %3388, %false_3966 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3389, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_3967 = torch.constant.int 0 - %3390 = torch.aten.clone %3389, %int0_3967 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3390, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_3968 = torch.constant.int 4 - %int32_3969 = torch.constant.int 32 - %int128_3970 = torch.constant.int 128 - %3391 = torch.prim.ListConstruct %int4_3968, %3330, %int32_3969, %int128_3970 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3392 = torch.aten._unsafe_view %3390, %3391 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3392, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_3971 = torch.constant.int -2 - %3393 = torch.aten.unsqueeze %3289, %int-2_3971 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3393, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_3972 = torch.constant.int 1 - %3394 = torch.aten.size.int %3283, %int1_3972 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_3973 = torch.constant.int 4 - %int8_3974 = torch.constant.int 8 - %int4_3975 = torch.constant.int 4 - %int128_3976 = torch.constant.int 128 - %3395 = torch.prim.ListConstruct %int4_3973, %3394, %int8_3974, %int4_3975, %int128_3976 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_3977 = torch.constant.bool false - %3396 = torch.aten.expand %3393, %3395, %false_3977 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3396, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_3978 = torch.constant.int 0 - %3397 = torch.aten.clone %3396, %int0_3978 : 
!torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3397, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_3979 = torch.constant.int 4 - %int32_3980 = torch.constant.int 32 - %int128_3981 = torch.constant.int 128 - %3398 = torch.prim.ListConstruct %int4_3979, %3394, %int32_3980, %int128_3981 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3399 = torch.aten._unsafe_view %3397, %3398 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3399, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_3982 = torch.constant.int 1 - %int2_3983 = torch.constant.int 2 - %3400 = torch.aten.transpose.int %3317, %int1_3982, %int2_3983 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3400, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int2_3951 = torch.constant.int 2 + %int4_3952 = torch.constant.int 4 + %none_3953 = torch.constant.none + %cpu_3954 = torch.constant.device "cpu" + %false_3955 = torch.constant.bool false + %3383 = torch.aten.arange.start_step %int0_3949, %int128_3950, %int2_3951, %int4_3952, %none_3953, %cpu_3954, %false_3955 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_3956 = torch.constant.int 6 + %3384 = torch.prims.convert_element_type %3383, %int6_3956 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_3957 = torch.constant.int 128 + %3385 = torch.aten.div.Scalar %3384, %int128_3957 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_3958 = torch.constant.float 5.000000e+05 + %3386 = torch.aten.pow.Scalar %float5.000000e05_3958, %3385 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3387 = torch.aten.reciprocal %3386 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_3959 = torch.constant.float 1.000000e+00 + %3388 = torch.aten.mul.Scalar %3387, %float1.000000e00_3959 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %3389 = torch.aten.reciprocal %3388 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_3960 = torch.constant.float 6.2831853071795862 + %3390 = torch.aten.mul.Scalar %3389, %float6.283190e00_3960 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_3961 = torch.constant.float 8.192000e+03 + %3391 = torch.aten.gt.Scalar %3390, %float8.192000e03_3961 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_3962 = torch.constant.int 8 + %3392 = torch.aten.div.Scalar %3388, %int8_3962 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %3393 = torch.aten.where.self %3391, %3392, %3388 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3394 = torch.aten.reciprocal %3390 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_3963 = torch.constant.int 8192 + %3395 = torch.aten.mul.Scalar %3394, %int8192_3963 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_3964 = torch.constant.int 1 + %int1_3965 = torch.constant.int 1 + %3396 = torch.aten.sub.Scalar %3395, %int1_3964, %int1_3965 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> 
!torch.vtensor<[64],f32> + %int3_3966 = torch.constant.int 3 + %3397 = torch.aten.div.Scalar %3396, %int3_3966 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_3967 = torch.constant.int 1 + %int1_3968 = torch.constant.int 1 + %3398 = torch.aten.rsub.Scalar %3397, %int1_3967, %int1_3968 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %3399 = torch.aten.mul.Tensor %3398, %3393 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_3969 = torch.constant.int 8 + %3400 = torch.aten.div.Scalar %3399, %int8_3969 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %3401 = torch.aten.mul.Tensor %3397, %3393 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_3970 = torch.constant.int 1 + %3402 = torch.aten.add.Tensor %3400, %3401, %int1_3970 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_3971 = torch.constant.float 2.048000e+03 + %3403 = torch.aten.lt.Scalar %3390, %float2.048000e03_3971 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %3404 = torch.aten.bitwise_not %3403 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_3972 = torch.constant.float 8.192000e+03 + %3405 = torch.aten.gt.Scalar %3390, %float8.192000e03_3972 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %3406 = torch.aten.bitwise_not %3405 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %3407 = torch.aten.mul.Tensor %3404, %3406 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %3408 = torch.aten.where.self %3407, %3402, %3393 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3409 = torch.prim.ListConstruct %3408, %3408 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_3973 = torch.constant.int -1 + %3410 = torch.aten.cat %3409, %int-1_3973 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_3974 = torch.constant.int 6 + %3411 = torch.prims.convert_element_type %3410, %int6_3974 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_3975 = torch.constant.int 1 + %3412 = torch.aten.unsqueeze %3382, %int1_3975 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_3976 = torch.constant.int 6 + %3413 = torch.prims.convert_element_type %3412, %int6_3976 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_3977 = torch.constant.int 0 + %3414 = torch.aten.unsqueeze %3411, %int0_3977 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_3978 = torch.constant.int 6 + %3415 = torch.prims.convert_element_type %3414, %int6_3978 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %3416 = torch.aten.mul.Tensor %3413, %3415 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %3417 = torch.aten.cos %3416 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_3979 = torch.constant.int 5 + %3418 = torch.prims.convert_element_type %3417, %int5_3979 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %3419 = torch.aten.sin %3416 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_3980 = torch.constant.int 5 + %3420 = torch.prims.convert_element_type %3419, %int5_3980 : 
!torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_3981 = torch.constant.int 0 + %int0_3982 = torch.constant.int 0 + %int1_3983 = torch.constant.int 1 + %3421 = torch.aten.slice.Tensor %3418, %int0_3981, %int0_3982, %298, %int1_3983 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3421, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_3984 = torch.constant.int 1 - %int2_3985 = torch.constant.int 2 - %3401 = torch.aten.transpose.int %3392, %int1_3984, %int2_3985 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3401, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_3986 = torch.constant.int 1 - %int2_3987 = torch.constant.int 2 - %3402 = torch.aten.transpose.int %3399, %int1_3986, %int2_3987 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3402, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_3988 = torch.constant.float 0.000000e+00 - %true_3989 = torch.constant.bool true - %none_3990 = torch.constant.none - %none_3991 = torch.constant.none - %3403:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3400, %3401, %3402, %float0.000000e00_3988, %true_3989, %none_3990, %none_3991) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %3403#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_3992 = torch.constant.int 1 - %int2_3993 = torch.constant.int 2 - %3404 = torch.aten.transpose.int %3403#0, %int1_3992, %int2_3993 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3404, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_3994 = torch.constant.int 4 - %int4096_3995 = torch.constant.int 4096 - %3405 = torch.prim.ListConstruct %int4_3994, %3302, %int4096_3995 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3406 = torch.aten.view %3404, %3405 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3406, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_3996 = torch.constant.int -2 - %int-1_3997 = torch.constant.int -1 - %3407 = torch.aten.transpose.int %140, %int-2_3996, %int-1_3997 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3998 = torch.constant.int 4 - %3408 = torch.aten.mul.int %int4_3998, %3302 : !torch.int, !torch.int -> !torch.int - %int4096_3999 = torch.constant.int 4096 - %3409 = torch.prim.ListConstruct %3408, %int4096_3999 : (!torch.int, !torch.int) -> !torch.list - %3410 = torch.aten.view %3406, %3409 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3410, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3411 = torch.aten.mm %3410, %3407 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - 
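// The `-` ops in this stretch are the old attention core: K and V [4,?,8,128] are expanded from
// 8 to 32 heads for grouped-query attention (unsqueeze/expand/clone/_unsafe_view to [4,?,32,128]),
// Q/K/V are transposed to [4,32,?,128], and attention runs through
// torch.aten._scaled_dot_product_flash_attention_for_cpu (dropout 0.0, is_causal = true, f16),
// followed by the transpose back and the re-view to [4,?,4096].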
torch.bind_symbolic_shape %3411, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_4000 = torch.constant.int 4 - %int4096_4001 = torch.constant.int 4096 - %3412 = torch.prim.ListConstruct %int4_4000, %3302, %int4096_4001 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3413 = torch.aten.view %3411, %3412 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3413, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_4002 = torch.constant.int 1 - %3414 = torch.aten.add.Tensor %3252, %3413, %int1_4002 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3414, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_4003 = torch.constant.int 6 - %3415 = torch.prims.convert_element_type %3414, %int6_4003 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3415, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_4004 = torch.constant.int 2 - %3416 = torch.aten.pow.Tensor_Scalar %3415, %int2_4004 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3416, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_4005 = torch.constant.int -1 - %3417 = torch.prim.ListConstruct %int-1_4005 : (!torch.int) -> !torch.list - %true_4006 = torch.constant.bool true - %none_4007 = torch.constant.none - %3418 = torch.aten.mean.dim %3416, %3417, %true_4006, %none_4007 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3418, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_4008 = torch.constant.float 9.9999997473787516E-6 - %int1_4009 = torch.constant.int 1 - %3419 = torch.aten.add.Scalar %3418, %float9.999990e-06_4008, %int1_4009 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3419, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3420 = torch.aten.rsqrt %3419 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3420, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3421 = torch.aten.mul.Tensor %3415, %3420 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3421, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_4010 = torch.constant.int 5 - %3422 = torch.prims.convert_element_type %3421, %int5_4010 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3422, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %3423 = torch.aten.mul.Tensor %141, %3422 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3423, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_4011 = torch.constant.int 5 - %3424 = torch.prims.convert_element_type %3423, %int5_4011 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3424, [%292], affine_map<()[s0] -> (4, 
s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_4012 = torch.constant.int -2 - %int-1_4013 = torch.constant.int -1 - %3425 = torch.aten.transpose.int %142, %int-2_4012, %int-1_4013 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_4014 = torch.constant.int 4 - %3426 = torch.aten.mul.int %int4_4014, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4015 = torch.constant.int 4096 - %3427 = torch.prim.ListConstruct %3426, %int4096_4015 : (!torch.int, !torch.int) -> !torch.list - %3428 = torch.aten.view %3424, %3427 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3428, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3429 = torch.aten.mm %3428, %3425 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %3429, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_4016 = torch.constant.int 4 - %int14336_4017 = torch.constant.int 14336 - %3430 = torch.prim.ListConstruct %int4_4016, %306, %int14336_4017 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3431 = torch.aten.view %3429, %3430 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3431, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %3432 = torch.aten.silu %3431 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3432, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_4018 = torch.constant.int -2 - %int-1_4019 = torch.constant.int -1 - %3433 = torch.aten.transpose.int %143, %int-2_4018, %int-1_4019 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_4020 = torch.constant.int 4 - %3434 = torch.aten.mul.int %int4_4020, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4021 = torch.constant.int 4096 - %3435 = torch.prim.ListConstruct %3434, %int4096_4021 : (!torch.int, !torch.int) -> !torch.list - %3436 = torch.aten.view %3424, %3435 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3436, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3437 = torch.aten.mm %3436, %3433 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %3437, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_4022 = torch.constant.int 4 - %int14336_4023 = torch.constant.int 14336 - %3438 = torch.prim.ListConstruct %int4_4022, %306, %int14336_4023 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3439 = torch.aten.view %3437, %3438 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3439, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %3440 = torch.aten.mul.Tensor %3432, %3439 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3440, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_4024 = torch.constant.int -2 - %int-1_4025 = torch.constant.int -1 - %3441 = torch.aten.transpose.int %144, %int-2_4024, %int-1_4025 : !torch.vtensor<[4096,14336],f16>, 
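// Also removed here: the old attn_output projection (%140) with its residual add, the ffn_norm
// RMSNorm (weight %141), and the SwiGLU feed-forward, silu(x @ gate^T (%142)) * (x @ up^T (%143))
// projected back through ffn_down (%144); the rebuilt equivalents presumably follow later in the hunk.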
!torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int0_3985 = torch.constant.int 0 + %int9223372036854775807_3986 = torch.constant.int 9223372036854775807 + %int1_3987 = torch.constant.int 1 + %3422 = torch.aten.slice.Tensor %3421, %int1_3984, %int0_3985, %int9223372036854775807_3986, %int1_3987 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3422, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_3988 = torch.constant.int 0 + %int0_3989 = torch.constant.int 0 + %int1_3990 = torch.constant.int 1 + %3423 = torch.aten.slice.Tensor %3420, %int0_3988, %int0_3989, %298, %int1_3990 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3423, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_3991 = torch.constant.int 1 + %int0_3992 = torch.constant.int 0 + %int9223372036854775807_3993 = torch.constant.int 9223372036854775807 + %int1_3994 = torch.constant.int 1 + %3424 = torch.aten.slice.Tensor %3423, %int1_3991, %int0_3992, %int9223372036854775807_3993, %int1_3994 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3424, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_3995 = torch.constant.int 0 + %3425 = torch.aten.unsqueeze %3422, %int0_3995 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3425, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_3996 = torch.constant.int 1 + %int0_3997 = torch.constant.int 0 + %int9223372036854775807_3998 = torch.constant.int 9223372036854775807 + %int1_3999 = torch.constant.int 1 + %3426 = torch.aten.slice.Tensor %3425, %int1_3996, %int0_3997, %int9223372036854775807_3998, %int1_3999 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3426, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_4000 = torch.constant.int 2 + %3427 = torch.aten.unsqueeze %3426, %int2_4000 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3427, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_4001 = torch.constant.int 3 + %int0_4002 = torch.constant.int 0 + %int9223372036854775807_4003 = torch.constant.int 9223372036854775807 + %int1_4004 = torch.constant.int 1 + %3428 = torch.aten.slice.Tensor %3427, %int3_4001, %int0_4002, %int9223372036854775807_4003, %int1_4004 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3428, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_4005 = torch.constant.int 4 + %int1_4006 = torch.constant.int 1 + %int1_4007 = torch.constant.int 1 + %int1_4008 = torch.constant.int 1 + %3429 = torch.prim.ListConstruct %int4_4005, %int1_4006, %int1_4007, %int1_4008 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3430 = torch.aten.repeat %3428, %3429 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3430, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : 
!torch.vtensor<[4,?,1,128],f16> + %int0_4009 = torch.constant.int 0 + %3431 = torch.aten.unsqueeze %3424, %int0_4009 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3431, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_4010 = torch.constant.int 1 + %int0_4011 = torch.constant.int 0 + %int9223372036854775807_4012 = torch.constant.int 9223372036854775807 + %int1_4013 = torch.constant.int 1 + %3432 = torch.aten.slice.Tensor %3431, %int1_4010, %int0_4011, %int9223372036854775807_4012, %int1_4013 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3432, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_4014 = torch.constant.int 2 + %3433 = torch.aten.unsqueeze %3432, %int2_4014 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3433, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_4015 = torch.constant.int 3 + %int0_4016 = torch.constant.int 0 + %int9223372036854775807_4017 = torch.constant.int 9223372036854775807 + %int1_4018 = torch.constant.int 1 + %3434 = torch.aten.slice.Tensor %3433, %int3_4015, %int0_4016, %int9223372036854775807_4017, %int1_4018 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3434, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_4019 = torch.constant.int 4 + %int1_4020 = torch.constant.int 1 + %int1_4021 = torch.constant.int 1 + %int1_4022 = torch.constant.int 1 + %3435 = torch.prim.ListConstruct %int4_4019, %int1_4020, %int1_4021, %int1_4022 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3436 = torch.aten.repeat %3434, %3435 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3436, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %3437 = torch.aten.mul.Tensor %3316, %3430 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3437, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_4023 = torch.constant.int 3 + %int0_4024 = torch.constant.int 0 + %int64_4025 = torch.constant.int 64 %int1_4026 = torch.constant.int 1 - %3442 = torch.aten.size.int %3431, %int1_4026 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_4027 = torch.constant.int 4 - %3443 = torch.aten.mul.int %int4_4027, %3442 : !torch.int, !torch.int -> !torch.int - %int14336_4028 = torch.constant.int 14336 - %3444 = torch.prim.ListConstruct %3443, %int14336_4028 : (!torch.int, !torch.int) -> !torch.list - %3445 = torch.aten.view %3440, %3444 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %3445, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %3446 = torch.aten.mm %3445, %3441 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3446, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_4029 = torch.constant.int 4 - %int4096_4030 = torch.constant.int 4096 - %3447 = torch.prim.ListConstruct %int4_4029, 
%3442, %int4096_4030 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3448 = torch.aten.view %3446, %3447 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3448, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_4031 = torch.constant.int 1 - %3449 = torch.aten.add.Tensor %3414, %3448, %int1_4031 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3449, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_4032 = torch.constant.int 6 - %3450 = torch.prims.convert_element_type %3449, %int6_4032 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3450, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_4033 = torch.constant.int 2 - %3451 = torch.aten.pow.Tensor_Scalar %3450, %int2_4033 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3451, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_4034 = torch.constant.int -1 - %3452 = torch.prim.ListConstruct %int-1_4034 : (!torch.int) -> !torch.list - %true_4035 = torch.constant.bool true - %none_4036 = torch.constant.none - %3453 = torch.aten.mean.dim %3451, %3452, %true_4035, %none_4036 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3453, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_4037 = torch.constant.float 9.9999997473787516E-6 + %3438 = torch.aten.slice.Tensor %3316, %int3_4023, %int0_4024, %int64_4025, %int1_4026 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %3438, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_4027 = torch.constant.int 3 + %int64_4028 = torch.constant.int 64 + %int9223372036854775807_4029 = torch.constant.int 9223372036854775807 + %int1_4030 = torch.constant.int 1 + %3439 = torch.aten.slice.Tensor %3316, %int3_4027, %int64_4028, %int9223372036854775807_4029, %int1_4030 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %3439, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %3440 = torch.aten.neg %3439 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %3440, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %3441 = torch.prim.ListConstruct %3440, %3438 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_4031 = torch.constant.int -1 + %3442 = torch.aten.cat %3441, %int-1_4031 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3442, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %3443 = torch.aten.mul.Tensor %3442, %3436 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3443, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_4032 = torch.constant.int 1 + %3444 = torch.aten.add.Tensor %3437, 
%3443, %int1_4032 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3444, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_4033 = torch.constant.int 32 + %3445 = torch.aten.mul.Scalar %arg2, %int32_4033 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3445, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int11_4034 = torch.constant.int 11 + %int1_4035 = torch.constant.int 1 + %3446 = torch.aten.add.Scalar %3445, %int11_4034, %int1_4035 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3446, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_4036 = torch.constant.int 2 + %3447 = torch.aten.mul.Scalar %3446, %int2_4036 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3447, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_4037 = torch.constant.int 0 %int1_4038 = torch.constant.int 1 - %3454 = torch.aten.add.Scalar %3453, %float9.999990e-06_4037, %int1_4038 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3454, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3455 = torch.aten.rsqrt %3454 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3455, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3456 = torch.aten.mul.Tensor %3450, %3455 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3456, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_4039 = torch.constant.int 5 - %3457 = torch.prims.convert_element_type %3456, %int5_4039 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3457, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %3458 = torch.aten.mul.Tensor %145, %3457 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3458, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_4040 = torch.constant.int 5 - %3459 = torch.prims.convert_element_type %3458, %int5_4040 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3459, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_4041 = torch.constant.int -2 - %int-1_4042 = torch.constant.int -1 - %3460 = torch.aten.transpose.int %146, %int-2_4041, %int-1_4042 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_4043 = torch.constant.int 4 - %3461 = torch.aten.mul.int %int4_4043, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4044 = torch.constant.int 4096 - %3462 = torch.prim.ListConstruct %3461, %int4096_4044 : (!torch.int, !torch.int) -> !torch.list - %3463 = torch.aten.view %3459, %3462 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3463, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3464 = torch.aten.mm %3463, %3460 : !torch.vtensor<[?,4096],f16>, 
!torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3464, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_4045 = torch.constant.int 4 - %int4096_4046 = torch.constant.int 4096 - %3465 = torch.prim.ListConstruct %int4_4045, %306, %int4096_4046 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3466 = torch.aten.view %3464, %3465 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3466, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_4047 = torch.constant.int -2 - %int-1_4048 = torch.constant.int -1 - %3467 = torch.aten.transpose.int %147, %int-2_4047, %int-1_4048 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_4049 = torch.constant.int 4 - %3468 = torch.aten.mul.int %int4_4049, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4050 = torch.constant.int 4096 - %3469 = torch.prim.ListConstruct %3468, %int4096_4050 : (!torch.int, !torch.int) -> !torch.list - %3470 = torch.aten.view %3459, %3469 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3470, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3471 = torch.aten.mm %3470, %3467 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %3471, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_4051 = torch.constant.int 4 - %int1024_4052 = torch.constant.int 1024 - %3472 = torch.prim.ListConstruct %int4_4051, %306, %int1024_4052 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3473 = torch.aten.view %3471, %3472 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %3473, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_4053 = torch.constant.int -2 - %int-1_4054 = torch.constant.int -1 - %3474 = torch.aten.transpose.int %148, %int-2_4053, %int-1_4054 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_4055 = torch.constant.int 4 - %3475 = torch.aten.mul.int %int4_4055, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4056 = torch.constant.int 4096 - %3476 = torch.prim.ListConstruct %3475, %int4096_4056 : (!torch.int, !torch.int) -> !torch.list - %3477 = torch.aten.view %3459, %3476 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3477, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3478 = torch.aten.mm %3477, %3474 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %3478, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_4057 = torch.constant.int 4 - %int1024_4058 = torch.constant.int 1024 - %3479 = torch.prim.ListConstruct %int4_4057, %306, %int1024_4058 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3480 = torch.aten.view %3478, %3479 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %3480, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_4059 = torch.constant.int 4 - %int32_4060 = torch.constant.int 32 - %int128_4061 = torch.constant.int 128 - %3481 = 
torch.prim.ListConstruct %int4_4059, %306, %int32_4060, %int128_4061 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3482 = torch.aten.view %3466, %3481 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3482, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_4062 = torch.constant.int 4 - %int8_4063 = torch.constant.int 8 - %int128_4064 = torch.constant.int 128 - %3483 = torch.prim.ListConstruct %int4_4062, %306, %int8_4063, %int128_4064 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3484 = torch.aten.view %3473, %3483 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3484, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_4065 = torch.constant.int 4 + %3448 = torch.aten.add.Scalar %3447, %int0_4037, %int1_4038 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3448, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %3449 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %3450 = torch.aten.view %3448, %3449 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %3450, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_4039 = torch.constant.int 4 + %int32_4040 = torch.constant.int 32 + %int8_4041 = torch.constant.int 8 + %int128_4042 = torch.constant.int 128 + %3451 = torch.prim.ListConstruct %int4_4039, %296, %int32_4040, %int8_4041, %int128_4042 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3452 = torch.aten.view %3444, %3451 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3452, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_4043 = torch.constant.int 32 + %int8_4044 = torch.constant.int 8 + %int128_4045 = torch.constant.int 128 + %3453 = torch.prim.ListConstruct %504, %int32_4043, %int8_4044, %int128_4045 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3454 = torch.aten.view %3452, %3453 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %3454, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_4046 = torch.constant.int 1 + %int2_4047 = torch.constant.int 2 + %3455 = torch.aten.transpose.int %3454, %int1_4046, %int2_4047 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3455, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_4048 = torch.constant.int 5 + %3456 = torch.prims.convert_element_type %3455, %int5_4048 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3456, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_4049 = torch.constant.int 32 + %int2_4050 = torch.constant.int 2 + %int8_4051 = torch.constant.int 8 + %int32_4052 = torch.constant.int 32 + %int128_4053 = torch.constant.int 128 + %3457 = torch.prim.ListConstruct %297, %int32_4049, %int2_4050, %int8_4051, %int32_4052, %int128_4053 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3458 = 
torch.aten.view %3220, %3457 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3458, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_4054 = torch.constant.int 8 + %int32_4055 = torch.constant.int 32 + %int128_4056 = torch.constant.int 128 + %3459 = torch.prim.ListConstruct %497, %int8_4054, %int32_4055, %int128_4056 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3460 = torch.aten.view %3458, %3459 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3460, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %3461 = torch.prim.ListConstruct %3450 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_4057 = torch.constant.bool false + %3462 = torch.aten.index_put %3460, %3461, %3456, %false_4057 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3462, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_4058 = torch.constant.int 32 + %int2_4059 = torch.constant.int 2 + %int8_4060 = torch.constant.int 8 + %int32_4061 = torch.constant.int 32 + %int128_4062 = torch.constant.int 128 + %3463 = torch.prim.ListConstruct %297, %int32_4058, %int2_4059, %int8_4060, %int32_4061, %int128_4062 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3464 = torch.aten.view %3462, %3463 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3464, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_4063 = torch.constant.int 2097152 + %3465 = torch.prim.ListConstruct %297, %int2097152_4063 : (!torch.int, !torch.int) -> !torch.list + %3466 = torch.aten.view %3464, %3465 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %3466, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_4064 = torch.constant.int 32 + %int2_4065 = torch.constant.int 2 %int8_4066 = torch.constant.int 8 - %int128_4067 = torch.constant.int 128 - %3485 = torch.prim.ListConstruct %int4_4065, %306, %int8_4066, %int128_4067 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3486 = torch.aten.view %3480, %3485 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3486, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_4068 = torch.constant.int 131072 - %none_4069 = torch.constant.none - %none_4070 = torch.constant.none - %cpu_4071 = torch.constant.device "cpu" - %false_4072 = torch.constant.bool false - %3487 = torch.aten.arange %int131072_4068, %none_4069, %none_4070, %cpu_4071, %false_4072 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_4073 = torch.constant.int 0 - %int128_4074 = torch.constant.int 128 - %none_4075 = torch.constant.none - %none_4076 = torch.constant.none - %cpu_4077 = torch.constant.device "cpu" - %false_4078 = torch.constant.bool false - %3488 = torch.aten.arange.start %int0_4073, %int128_4074, %none_4075, %none_4076, %cpu_4077, %false_4078 : !torch.int, !torch.int, !torch.none, !torch.none, 
!torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_4079 = torch.constant.int 2 - %3489 = torch.aten.floor_divide.Scalar %3488, %int2_4079 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_4080 = torch.constant.int 6 - %3490 = torch.prims.convert_element_type %3489, %int6_4080 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> + %int32_4067 = torch.constant.int 32 + %int128_4068 = torch.constant.int 128 + %3467 = torch.prim.ListConstruct %297, %int32_4064, %int2_4065, %int8_4066, %int32_4067, %int128_4068 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3468 = torch.aten.view %3466, %3467 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3468, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_4069 = torch.constant.int 8 + %int32_4070 = torch.constant.int 32 + %int128_4071 = torch.constant.int 128 + %3469 = torch.prim.ListConstruct %497, %int8_4069, %int32_4070, %int128_4071 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3470 = torch.aten.view %3468, %3469 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3470, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_4072 = torch.constant.int 32 + %3471 = torch.aten.mul.Scalar %arg2, %int32_4072 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3471, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int11_4073 = torch.constant.int 11 + %int1_4074 = torch.constant.int 1 + %3472 = torch.aten.add.Scalar %3471, %int11_4073, %int1_4074 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3472, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_4075 = torch.constant.int 2 + %3473 = torch.aten.mul.Scalar %3472, %int2_4075 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3473, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_4076 = torch.constant.int 1 + %int1_4077 = torch.constant.int 1 + %3474 = torch.aten.add.Scalar %3473, %int1_4076, %int1_4077 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3474, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %3475 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %3476 = torch.aten.view %3474, %3475 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %3476, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_4078 = torch.constant.int 4 + %int32_4079 = torch.constant.int 32 + %int8_4080 = torch.constant.int 8 %int128_4081 = torch.constant.int 128 - %3491 = torch.aten.div.Scalar %3490, %int128_4081 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_4082 = torch.constant.float 2.000000e+00 - %3492 = torch.aten.mul.Scalar %3491, %float2.000000e00_4082 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_4083 = torch.constant.float 5.000000e+05 - %3493 = torch.aten.pow.Scalar %float5.000000e05_4083, %3492 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %3494 = 
torch.aten.reciprocal %3493 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_4084 = torch.constant.float 1.000000e+00 - %3495 = torch.aten.mul.Scalar %3494, %float1.000000e00_4084 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %3477 = torch.prim.ListConstruct %int4_4078, %296, %int32_4079, %int8_4080, %int128_4081 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3478 = torch.aten.view %3318, %3477 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3478, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_4082 = torch.constant.int 32 + %int8_4083 = torch.constant.int 8 + %int128_4084 = torch.constant.int 128 + %3479 = torch.prim.ListConstruct %504, %int32_4082, %int8_4083, %int128_4084 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3480 = torch.aten.view %3478, %3479 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %3480, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> %int1_4085 = torch.constant.int 1 - %3496 = torch.aten.unsqueeze %3487, %int1_4085 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_4086 = torch.constant.int 0 - %3497 = torch.aten.unsqueeze %3495, %int0_4086 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %3498 = torch.aten.mul.Tensor %3496, %3497 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_4087 = torch.constant.int 1 - %3499 = torch.aten.size.int %3466, %int1_4087 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_4088 = torch.constant.int 0 - %3500 = torch.aten.add.int %int0_4088, %3499 : !torch.int, !torch.int -> !torch.int - %int0_4089 = torch.constant.int 0 - %int0_4090 = torch.constant.int 0 - %int1_4091 = torch.constant.int 1 - %3501 = torch.aten.slice.Tensor %3498, %int0_4089, %int0_4090, %3500, %int1_4091 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3501, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_4092 = torch.constant.int 1 - %int0_4093 = torch.constant.int 0 - %int9223372036854775807_4094 = torch.constant.int 9223372036854775807 - %int1_4095 = torch.constant.int 1 - %3502 = torch.aten.slice.Tensor %3501, %int1_4092, %int0_4093, %int9223372036854775807_4094, %int1_4095 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3502, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_4096 = torch.constant.int 1 - %int0_4097 = torch.constant.int 0 - %int9223372036854775807_4098 = torch.constant.int 9223372036854775807 - %int1_4099 = torch.constant.int 1 - %3503 = torch.aten.slice.Tensor %3502, %int1_4096, %int0_4097, %int9223372036854775807_4098, %int1_4099 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3503, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_4100 = torch.constant.int 0 - %3504 = torch.aten.unsqueeze %3503, %int0_4100 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3504, [%292], 
affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_4101 = torch.constant.int 1 - %int0_4102 = torch.constant.int 0 - %int9223372036854775807_4103 = torch.constant.int 9223372036854775807 - %int1_4104 = torch.constant.int 1 - %3505 = torch.aten.slice.Tensor %3504, %int1_4101, %int0_4102, %int9223372036854775807_4103, %int1_4104 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3505, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_4105 = torch.constant.int 2 - %int0_4106 = torch.constant.int 0 - %int9223372036854775807_4107 = torch.constant.int 9223372036854775807 - %int1_4108 = torch.constant.int 1 - %3506 = torch.aten.slice.Tensor %3505, %int2_4105, %int0_4106, %int9223372036854775807_4107, %int1_4108 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3506, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_4109 = torch.constant.int 4 - %int1_4110 = torch.constant.int 1 - %int1_4111 = torch.constant.int 1 - %3507 = torch.prim.ListConstruct %int4_4109, %int1_4110, %int1_4111 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3508 = torch.aten.repeat %3506, %3507 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %3508, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_4112 = torch.constant.int 6 - %3509 = torch.prims.convert_element_type %3482, %int6_4112 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %3509, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %3510 = torch_c.to_builtin_tensor %3509 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %3511 = torch_c.to_builtin_tensor %3508 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %3512 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%3510, %3511) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %3513 = torch_c.from_builtin_tensor %3512 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %3513, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_4113 = torch.constant.int 5 - %3514 = torch.prims.convert_element_type %3513, %int5_4113 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3514, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_4114 = torch.constant.int 131072 - %none_4115 = torch.constant.none - %none_4116 = torch.constant.none - %cpu_4117 = torch.constant.device "cpu" - %false_4118 = torch.constant.bool false - %3515 = torch.aten.arange %int131072_4114, %none_4115, %none_4116, %cpu_4117, %false_4118 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_4119 = torch.constant.int 0 - %int128_4120 = torch.constant.int 128 - %none_4121 = torch.constant.none - %none_4122 = torch.constant.none - %cpu_4123 = torch.constant.device "cpu" - %false_4124 = torch.constant.bool false - %3516 = torch.aten.arange.start %int0_4119, %int128_4120, %none_4121, %none_4122, %cpu_4123, %false_4124 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, 
!torch.bool -> !torch.vtensor<[128],si64> + %int2_4086 = torch.constant.int 2 + %3481 = torch.aten.transpose.int %3480, %int1_4085, %int2_4086 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3481, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_4087 = torch.constant.int 5 + %3482 = torch.prims.convert_element_type %3481, %int5_4087 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3482, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %3483 = torch.prim.ListConstruct %3476 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_4088 = torch.constant.bool false + %3484 = torch.aten.index_put %3470, %3483, %3482, %false_4088 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3484, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_4089 = torch.constant.int 32 + %int2_4090 = torch.constant.int 2 + %int8_4091 = torch.constant.int 8 + %int32_4092 = torch.constant.int 32 + %int128_4093 = torch.constant.int 128 + %3485 = torch.prim.ListConstruct %297, %int32_4089, %int2_4090, %int8_4091, %int32_4092, %int128_4093 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3486 = torch.aten.view %3484, %3485 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3486, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_4094 = torch.constant.int 2097152 + %3487 = torch.prim.ListConstruct %297, %int2097152_4094 : (!torch.int, !torch.int) -> !torch.list + %3488 = torch.aten.view %3486, %3487 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %3488, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_4095 = torch.constant.int -2 + %3489 = torch.aten.unsqueeze %3444, %int-2_4095 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %3489, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_4096 = torch.constant.int 4 + %int8_4097 = torch.constant.int 8 + %int4_4098 = torch.constant.int 4 + %int128_4099 = torch.constant.int 128 + %3490 = torch.prim.ListConstruct %int4_4096, %298, %int8_4097, %int4_4098, %int128_4099 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_4100 = torch.constant.bool false + %3491 = torch.aten.expand %3489, %3490, %false_4100 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3491, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_4101 = torch.constant.int 0 + %3492 = torch.aten.clone %3491, %int0_4101 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3492, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_4102 = torch.constant.int 4 + %int32_4103 = torch.constant.int 32 + %int128_4104 = torch.constant.int 128 + %3493 = torch.prim.ListConstruct %int4_4102, %298, %int32_4103, %int128_4104 : 
(!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3494 = torch.aten._unsafe_view %3492, %3493 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3494, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_4105 = torch.constant.int -2 + %3495 = torch.aten.unsqueeze %3318, %int-2_4105 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %3495, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_4106 = torch.constant.int 4 + %int8_4107 = torch.constant.int 8 + %int4_4108 = torch.constant.int 4 + %int128_4109 = torch.constant.int 128 + %3496 = torch.prim.ListConstruct %int4_4106, %298, %int8_4107, %int4_4108, %int128_4109 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_4110 = torch.constant.bool false + %3497 = torch.aten.expand %3495, %3496, %false_4110 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3497, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_4111 = torch.constant.int 0 + %3498 = torch.aten.clone %3497, %int0_4111 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3498, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_4112 = torch.constant.int 4 + %int32_4113 = torch.constant.int 32 + %int128_4114 = torch.constant.int 128 + %3499 = torch.prim.ListConstruct %int4_4112, %298, %int32_4113, %int128_4114 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3500 = torch.aten._unsafe_view %3498, %3499 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3500, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_4115 = torch.constant.int 1 + %int2_4116 = torch.constant.int 2 + %3501 = torch.aten.transpose.int %3381, %int1_4115, %int2_4116 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3501, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_4117 = torch.constant.int 1 + %int2_4118 = torch.constant.int 2 + %3502 = torch.aten.transpose.int %3494, %int1_4117, %int2_4118 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3502, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_4119 = torch.constant.int 1 + %int2_4120 = torch.constant.int 2 + %3503 = torch.aten.transpose.int %3500, %int1_4119, %int2_4120 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3503, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_4121 = torch.constant.float 0.000000e+00 + %false_4122 = torch.constant.bool false + %none_4123 = torch.constant.none + %3504:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3501, %3502, %3503, %float0.000000e00_4121, %false_4122, %327, %none_4123) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, 
!torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %3504#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_4124 = torch.constant.int 1 %int2_4125 = torch.constant.int 2 - %3517 = torch.aten.floor_divide.Scalar %3516, %int2_4125 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_4126 = torch.constant.int 6 - %3518 = torch.prims.convert_element_type %3517, %int6_4126 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_4127 = torch.constant.int 128 - %3519 = torch.aten.div.Scalar %3518, %int128_4127 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_4128 = torch.constant.float 2.000000e+00 - %3520 = torch.aten.mul.Scalar %3519, %float2.000000e00_4128 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_4129 = torch.constant.float 5.000000e+05 - %3521 = torch.aten.pow.Scalar %float5.000000e05_4129, %3520 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %3522 = torch.aten.reciprocal %3521 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_4130 = torch.constant.float 1.000000e+00 - %3523 = torch.aten.mul.Scalar %3522, %float1.000000e00_4130 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_4131 = torch.constant.int 1 - %3524 = torch.aten.unsqueeze %3515, %int1_4131 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_4132 = torch.constant.int 0 - %3525 = torch.aten.unsqueeze %3523, %int0_4132 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %3526 = torch.aten.mul.Tensor %3524, %3525 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_4133 = torch.constant.int 1 - %3527 = torch.aten.size.int %3473, %int1_4133 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_4134 = torch.constant.int 0 - %3528 = torch.aten.add.int %int0_4134, %3527 : !torch.int, !torch.int -> !torch.int - %int0_4135 = torch.constant.int 0 - %int0_4136 = torch.constant.int 0 - %int1_4137 = torch.constant.int 1 - %3529 = torch.aten.slice.Tensor %3526, %int0_4135, %int0_4136, %3528, %int1_4137 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3529, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_4138 = torch.constant.int 1 - %int0_4139 = torch.constant.int 0 - %int9223372036854775807_4140 = torch.constant.int 9223372036854775807 + %3505 = torch.aten.transpose.int %3504#0, %int1_4124, %int2_4125 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3505, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_4126 = torch.constant.int 4 + %int4096_4127 = torch.constant.int 4096 + %3506 = torch.prim.ListConstruct %int4_4126, %298, %int4096_4127 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3507 = torch.aten.view %3505, %3506 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3507, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_4128 = torch.constant.int -2 + %int-1_4129 = torch.constant.int -1 + %3508 = 
torch.aten.transpose.int %105, %int-2_4128, %int-1_4129 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_4130 = torch.constant.int 5 + %3509 = torch.prims.convert_element_type %3508, %int5_4130 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_4131 = torch.constant.int 4096 + %3510 = torch.prim.ListConstruct %342, %int4096_4131 : (!torch.int, !torch.int) -> !torch.list + %3511 = torch.aten.view %3507, %3510 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3511, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3512 = torch.aten.mm %3511, %3509 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3512, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_4132 = torch.constant.int 4 + %int4096_4133 = torch.constant.int 4096 + %3513 = torch.prim.ListConstruct %int4_4132, %298, %int4096_4133 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3514 = torch.aten.view %3512, %3513 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3514, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_4134 = torch.constant.int 1 + %3515 = torch.aten.add.Tensor %3281, %3514, %int1_4134 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3515, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_4135 = torch.constant.int 6 + %3516 = torch.prims.convert_element_type %3515, %int6_4135 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3516, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_4136 = torch.constant.int 2 + %3517 = torch.aten.pow.Tensor_Scalar %3516, %int2_4136 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3517, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_4137 = torch.constant.int -1 + %3518 = torch.prim.ListConstruct %int-1_4137 : (!torch.int) -> !torch.list + %true_4138 = torch.constant.bool true + %none_4139 = torch.constant.none + %3519 = torch.aten.mean.dim %3517, %3518, %true_4138, %none_4139 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3519, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_4140 = torch.constant.float 9.9999997473787516E-6 %int1_4141 = torch.constant.int 1 - %3530 = torch.aten.slice.Tensor %3529, %int1_4138, %int0_4139, %int9223372036854775807_4140, %int1_4141 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3530, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_4142 = torch.constant.int 1 - %int0_4143 = torch.constant.int 0 - %int9223372036854775807_4144 = torch.constant.int 9223372036854775807 - %int1_4145 = torch.constant.int 1 - %3531 = torch.aten.slice.Tensor %3530, %int1_4142, %int0_4143, %int9223372036854775807_4144, %int1_4145 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> 
!torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3531, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_4146 = torch.constant.int 0 - %3532 = torch.aten.unsqueeze %3531, %int0_4146 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3532, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_4147 = torch.constant.int 1 - %int0_4148 = torch.constant.int 0 - %int9223372036854775807_4149 = torch.constant.int 9223372036854775807 - %int1_4150 = torch.constant.int 1 - %3533 = torch.aten.slice.Tensor %3532, %int1_4147, %int0_4148, %int9223372036854775807_4149, %int1_4150 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3533, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_4151 = torch.constant.int 2 - %int0_4152 = torch.constant.int 0 - %int9223372036854775807_4153 = torch.constant.int 9223372036854775807 - %int1_4154 = torch.constant.int 1 - %3534 = torch.aten.slice.Tensor %3533, %int2_4151, %int0_4152, %int9223372036854775807_4153, %int1_4154 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3534, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_4155 = torch.constant.int 4 - %int1_4156 = torch.constant.int 1 - %int1_4157 = torch.constant.int 1 - %3535 = torch.prim.ListConstruct %int4_4155, %int1_4156, %int1_4157 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3536 = torch.aten.repeat %3534, %3535 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %3536, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_4158 = torch.constant.int 6 - %3537 = torch.prims.convert_element_type %3484, %int6_4158 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %3537, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %3538 = torch_c.to_builtin_tensor %3537 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %3539 = torch_c.to_builtin_tensor %3536 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %3540 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%3538, %3539) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %3541 = torch_c.from_builtin_tensor %3540 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %3541, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_4159 = torch.constant.int 5 - %3542 = torch.prims.convert_element_type %3541, %int5_4159 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3542, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_4160 = torch.constant.int 64 - %3543 = torch.aten.mul.Scalar %arg2, %int64_4160 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3543, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_4161 = torch.constant.int 32 + %3520 = torch.aten.add.Scalar %3519, %float9.999990e-06_4140, %int1_4141 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + 
torch.bind_symbolic_shape %3520, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %3521 = torch.aten.rsqrt %3520 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3521, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %3522 = torch.aten.mul.Tensor %3516, %3521 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3522, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_4142 = torch.constant.int 5 + %3523 = torch.prims.convert_element_type %3522, %int5_4142 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3523, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %3524 = torch.aten.mul.Tensor %106, %3523 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3524, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_4143 = torch.constant.int 5 + %3525 = torch.prims.convert_element_type %3524, %int5_4143 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3525, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_4144 = torch.constant.int -2 + %int-1_4145 = torch.constant.int -1 + %3526 = torch.aten.transpose.int %107, %int-2_4144, %int-1_4145 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_4146 = torch.constant.int 5 + %3527 = torch.prims.convert_element_type %3526, %int5_4146 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_4147 = torch.constant.int 4096 + %3528 = torch.prim.ListConstruct %342, %int4096_4147 : (!torch.int, !torch.int) -> !torch.list + %3529 = torch.aten.view %3525, %3528 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3529, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3530 = torch.aten.mm %3529, %3527 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %3530, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_4148 = torch.constant.int 4 + %int14336_4149 = torch.constant.int 14336 + %3531 = torch.prim.ListConstruct %int4_4148, %298, %int14336_4149 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3532 = torch.aten.view %3530, %3531 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %3532, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %3533 = torch.aten.silu %3532 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %3533, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_4150 = torch.constant.int -2 + %int-1_4151 = torch.constant.int -1 + %3534 = torch.aten.transpose.int %108, %int-2_4150, %int-1_4151 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_4152 = torch.constant.int 5 + %3535 = torch.prims.convert_element_type %3534, %int5_4152 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_4153 = 
torch.constant.int 4096 + %3536 = torch.prim.ListConstruct %342, %int4096_4153 : (!torch.int, !torch.int) -> !torch.list + %3537 = torch.aten.view %3525, %3536 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3537, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3538 = torch.aten.mm %3537, %3535 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %3538, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_4154 = torch.constant.int 4 + %int14336_4155 = torch.constant.int 14336 + %3539 = torch.prim.ListConstruct %int4_4154, %298, %int14336_4155 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3540 = torch.aten.view %3538, %3539 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %3540, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %3541 = torch.aten.mul.Tensor %3533, %3540 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %3541, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_4156 = torch.constant.int -2 + %int-1_4157 = torch.constant.int -1 + %3542 = torch.aten.transpose.int %109, %int-2_4156, %int-1_4157 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_4158 = torch.constant.int 5 + %3543 = torch.prims.convert_element_type %3542, %int5_4158 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_4159 = torch.constant.int 14336 + %3544 = torch.prim.ListConstruct %342, %int14336_4159 : (!torch.int, !torch.int) -> !torch.list + %3545 = torch.aten.view %3541, %3544 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %3545, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %3546 = torch.aten.mm %3545, %3543 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3546, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_4160 = torch.constant.int 4 + %int4096_4161 = torch.constant.int 4096 + %3547 = torch.prim.ListConstruct %int4_4160, %298, %int4096_4161 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3548 = torch.aten.view %3546, %3547 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3548, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> %int1_4162 = torch.constant.int 1 - %3544 = torch.aten.add.Scalar %3543, %int32_4161, %int1_4162 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3544, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_4163 = torch.constant.int 4 - %int32_4164 = torch.constant.int 32 - %int8_4165 = torch.constant.int 8 - %int128_4166 = torch.constant.int 128 - %3545 = torch.prim.ListConstruct %int4_4163, %398, %int32_4164, %int8_4165, %int128_4166 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3546 = torch.aten.view %3542, %3545 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3546, [%292], 
affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_4167 = torch.constant.int 4 - %3547 = torch.aten.mul.int %int4_4167, %398 : !torch.int, !torch.int -> !torch.int - %int32_4168 = torch.constant.int 32 - %int8_4169 = torch.constant.int 8 - %int128_4170 = torch.constant.int 128 - %3548 = torch.prim.ListConstruct %3547, %int32_4168, %int8_4169, %int128_4170 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3549 = torch.aten.view %3546, %3548 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3549, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_4171 = torch.constant.int 4 - %3550 = torch.aten.mul.int %int4_4171, %398 : !torch.int, !torch.int -> !torch.int - %3551 = torch.prim.ListConstruct %3550 : (!torch.int) -> !torch.list - %3552 = torch.aten.view %3544, %3551 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3552, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_4172 = torch.constant.int 32 - %int2_4173 = torch.constant.int 2 - %int32_4174 = torch.constant.int 32 - %int8_4175 = torch.constant.int 8 - %int128_4176 = torch.constant.int 128 - %3553 = torch.prim.ListConstruct %389, %int32_4172, %int2_4173, %int32_4174, %int8_4175, %int128_4176 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3554 = torch.aten.view %3386, %3553 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3554, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4177 = torch.constant.int 32 - %3555 = torch.aten.mul.int %389, %int32_4177 : !torch.int, !torch.int -> !torch.int - %int2_4178 = torch.constant.int 2 - %3556 = torch.aten.mul.int %3555, %int2_4178 : !torch.int, !torch.int -> !torch.int - %int32_4179 = torch.constant.int 32 - %int8_4180 = torch.constant.int 8 - %int128_4181 = torch.constant.int 128 - %3557 = torch.prim.ListConstruct %3556, %int32_4179, %int8_4180, %int128_4181 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3558 = torch.aten.view %3554, %3557 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3558, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %3559 = torch.prim.ListConstruct %3552 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_4182 = torch.constant.bool false - %3560 = torch.aten.index_put %3558, %3559, %3549, %false_4182 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3560, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_4183 = torch.constant.int 32 - %int2_4184 = torch.constant.int 2 - %int32_4185 = torch.constant.int 32 - %int8_4186 = torch.constant.int 8 - %int128_4187 = torch.constant.int 128 - %3561 = torch.prim.ListConstruct %389, %int32_4183, %int2_4184, %int32_4185, %int8_4186, %int128_4187 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3562 = torch.aten.view %3560, %3561 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3562, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : 
!torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_4188 = torch.constant.int 2097152 - %3563 = torch.prim.ListConstruct %389, %int2097152_4188 : (!torch.int, !torch.int) -> !torch.list - %3564 = torch.aten.view %3562, %3563 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3564, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_4189 = torch.constant.int 32 - %int2_4190 = torch.constant.int 2 + %3549 = torch.aten.add.Tensor %3515, %3548, %int1_4162 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3549, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_4163 = torch.constant.int 6 + %3550 = torch.prims.convert_element_type %3549, %int6_4163 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3550, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_4164 = torch.constant.int 2 + %3551 = torch.aten.pow.Tensor_Scalar %3550, %int2_4164 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3551, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_4165 = torch.constant.int -1 + %3552 = torch.prim.ListConstruct %int-1_4165 : (!torch.int) -> !torch.list + %true_4166 = torch.constant.bool true + %none_4167 = torch.constant.none + %3553 = torch.aten.mean.dim %3551, %3552, %true_4166, %none_4167 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3553, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_4168 = torch.constant.float 9.9999997473787516E-6 + %int1_4169 = torch.constant.int 1 + %3554 = torch.aten.add.Scalar %3553, %float9.999990e-06_4168, %int1_4169 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3554, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %3555 = torch.aten.rsqrt %3554 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3555, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %3556 = torch.aten.mul.Tensor %3550, %3555 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3556, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_4170 = torch.constant.int 5 + %3557 = torch.prims.convert_element_type %3556, %int5_4170 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3557, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %3558 = torch.aten.mul.Tensor %110, %3557 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3558, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_4171 = torch.constant.int 5 + %3559 = torch.prims.convert_element_type %3558, %int5_4171 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3559, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_4172 = 
torch.constant.int -2 + %int-1_4173 = torch.constant.int -1 + %3560 = torch.aten.transpose.int %111, %int-2_4172, %int-1_4173 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_4174 = torch.constant.int 5 + %3561 = torch.prims.convert_element_type %3560, %int5_4174 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_4175 = torch.constant.int 4096 + %3562 = torch.prim.ListConstruct %342, %int4096_4175 : (!torch.int, !torch.int) -> !torch.list + %3563 = torch.aten.view %3559, %3562 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3563, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3564 = torch.aten.mm %3563, %3561 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3564, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_4176 = torch.constant.int 4 + %int4096_4177 = torch.constant.int 4096 + %3565 = torch.prim.ListConstruct %int4_4176, %298, %int4096_4177 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3566 = torch.aten.view %3564, %3565 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3566, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_4178 = torch.constant.int -2 + %int-1_4179 = torch.constant.int -1 + %3567 = torch.aten.transpose.int %112, %int-2_4178, %int-1_4179 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_4180 = torch.constant.int 5 + %3568 = torch.prims.convert_element_type %3567, %int5_4180 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_4181 = torch.constant.int 4096 + %3569 = torch.prim.ListConstruct %342, %int4096_4181 : (!torch.int, !torch.int) -> !torch.list + %3570 = torch.aten.view %3559, %3569 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3570, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3571 = torch.aten.mm %3570, %3568 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %3571, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_4182 = torch.constant.int 4 + %int1024_4183 = torch.constant.int 1024 + %3572 = torch.prim.ListConstruct %int4_4182, %298, %int1024_4183 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3573 = torch.aten.view %3571, %3572 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %3573, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_4184 = torch.constant.int -2 + %int-1_4185 = torch.constant.int -1 + %3574 = torch.aten.transpose.int %113, %int-2_4184, %int-1_4185 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_4186 = torch.constant.int 5 + %3575 = torch.prims.convert_element_type %3574, %int5_4186 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_4187 = torch.constant.int 4096 + %3576 = torch.prim.ListConstruct %342, %int4096_4187 : (!torch.int, !torch.int) -> !torch.list + %3577 = torch.aten.view %3559, %3576 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> 
!torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3577, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3578 = torch.aten.mm %3577, %3575 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %3578, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_4188 = torch.constant.int 4 + %int1024_4189 = torch.constant.int 1024 + %3579 = torch.prim.ListConstruct %int4_4188, %298, %int1024_4189 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3580 = torch.aten.view %3578, %3579 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %3580, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_4190 = torch.constant.int 4 %int32_4191 = torch.constant.int 32 - %int8_4192 = torch.constant.int 8 - %int128_4193 = torch.constant.int 128 - %3565 = torch.prim.ListConstruct %389, %int32_4189, %int2_4190, %int32_4191, %int8_4192, %int128_4193 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3566 = torch.aten.view %3564, %3565 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3566, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4194 = torch.constant.int 32 - %int8_4195 = torch.constant.int 8 - %int128_4196 = torch.constant.int 128 - %3567 = torch.prim.ListConstruct %3556, %int32_4194, %int8_4195, %int128_4196 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3568 = torch.aten.view %3566, %3567 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3568, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_4197 = torch.constant.int 4 - %int32_4198 = torch.constant.int 32 - %int8_4199 = torch.constant.int 8 - %int128_4200 = torch.constant.int 128 - %3569 = torch.prim.ListConstruct %int4_4197, %398, %int32_4198, %int8_4199, %int128_4200 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3570 = torch.aten.view %3486, %3569 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3570, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_4201 = torch.constant.int 4 - %3571 = torch.aten.mul.int %int4_4201, %398 : !torch.int, !torch.int -> !torch.int - %int32_4202 = torch.constant.int 32 - %int8_4203 = torch.constant.int 8 - %int128_4204 = torch.constant.int 128 - %3572 = torch.prim.ListConstruct %3571, %int32_4202, %int8_4203, %int128_4204 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3573 = torch.aten.view %3570, %3572 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3573, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_4205 = torch.constant.int 1 - %int1_4206 = torch.constant.int 1 - %3574 = torch.aten.add.Scalar %3544, %int1_4205, %int1_4206 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3574, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int128_4192 = torch.constant.int 128 + %3581 = torch.prim.ListConstruct %int4_4190, %298, %int32_4191, 
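The three matmuls just added (%3564, %3571, %3578) are this block's attention projections: the normalized activations are flattened to [bs*seq, 4096] and multiplied by the transposed weights, yielding 4096 query channels (32 heads x 128) but only 1024 key/value channels (8 KV heads x 128), the grouped-query layout of Llama 3 8B; the head reshapes follow just below. Roughly, in PyTorch terms (the IR runs in f16; shapes and names here are illustrative):

import torch

bs, seq, d_model, head_dim = 4, 32, 4096, 128
n_heads, n_kv_heads = 32, 8          # 4 query heads share each KV head

x  = torch.randn(bs, seq, d_model)
wq = torch.randn(n_heads * head_dim, d_model)     # attn_q, [4096, 4096]
wk = torch.randn(n_kv_heads * head_dim, d_model)  # attn_k, [1024, 4096]
wv = torch.randn(n_kv_heads * head_dim, d_model)  # attn_v, [1024, 4096]

flat = x.view(-1, d_model)                        # the aten.view to [?, 4096]
q = (flat @ wq.T).view(bs, seq, n_heads, head_dim)     # [4, seq, 32, 128]
k = (flat @ wk.T).view(bs, seq, n_kv_heads, head_dim)  # [4, seq, 8, 128]
v = (flat @ wv.T).view(bs, seq, n_kv_heads, head_dim)  # [4, seq, 8, 128]
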
%int128_4192 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3582 = torch.aten.view %3566, %3581 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3582, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_4193 = torch.constant.int 4 + %int8_4194 = torch.constant.int 8 + %int128_4195 = torch.constant.int 128 + %3583 = torch.prim.ListConstruct %int4_4193, %298, %int8_4194, %int128_4195 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3584 = torch.aten.view %3573, %3583 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3584, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_4196 = torch.constant.int 4 + %int8_4197 = torch.constant.int 8 + %int128_4198 = torch.constant.int 128 + %3585 = torch.prim.ListConstruct %int4_4196, %298, %int8_4197, %int128_4198 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3586 = torch.aten.view %3580, %3585 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3586, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_4199 = torch.constant.int 131072 + %none_4200 = torch.constant.none + %none_4201 = torch.constant.none + %cpu_4202 = torch.constant.device "cpu" + %false_4203 = torch.constant.bool false + %3587 = torch.aten.arange %int131072_4199, %none_4200, %none_4201, %cpu_4202, %false_4203 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_4204 = torch.constant.int 0 + %int128_4205 = torch.constant.int 128 + %int2_4206 = torch.constant.int 2 %int4_4207 = torch.constant.int 4 - %3575 = torch.aten.mul.int %int4_4207, %398 : !torch.int, !torch.int -> !torch.int - %3576 = torch.prim.ListConstruct %3575 : (!torch.int) -> !torch.list - %3577 = torch.aten.view %3574, %3576 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3577, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %3578 = torch.prim.ListConstruct %3577 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_4208 = torch.constant.bool false - %3579 = torch.aten.index_put %3568, %3578, %3573, %false_4208 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3579, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_4209 = torch.constant.int 32 - %int2_4210 = torch.constant.int 2 - %int32_4211 = torch.constant.int 32 - %int8_4212 = torch.constant.int 8 - %int128_4213 = torch.constant.int 128 - %3580 = torch.prim.ListConstruct %389, %int32_4209, %int2_4210, %int32_4211, %int8_4212, %int128_4213 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3581 = torch.aten.view %3579, %3580 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3581, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_4214 = torch.constant.int 2097152 - %3582 = torch.prim.ListConstruct %389, %int2097152_4214 : (!torch.int, !torch.int) -> !torch.list - %3583 = torch.aten.view %3581, %3582 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> 
!torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3583, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_4215 = torch.constant.int -2 - %3584 = torch.aten.unsqueeze %3542, %int-2_4215 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3584, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_4216 = torch.constant.int 4 + %none_4208 = torch.constant.none + %cpu_4209 = torch.constant.device "cpu" + %false_4210 = torch.constant.bool false + %3588 = torch.aten.arange.start_step %int0_4204, %int128_4205, %int2_4206, %int4_4207, %none_4208, %cpu_4209, %false_4210 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_4211 = torch.constant.int 6 + %3589 = torch.prims.convert_element_type %3588, %int6_4211 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_4212 = torch.constant.int 128 + %3590 = torch.aten.div.Scalar %3589, %int128_4212 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_4213 = torch.constant.float 5.000000e+05 + %3591 = torch.aten.pow.Scalar %float5.000000e05_4213, %3590 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3592 = torch.aten.reciprocal %3591 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_4214 = torch.constant.float 1.000000e+00 + %3593 = torch.aten.mul.Scalar %3592, %float1.000000e00_4214 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %3594 = torch.aten.reciprocal %3593 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_4215 = torch.constant.float 6.2831853071795862 + %3595 = torch.aten.mul.Scalar %3594, %float6.283190e00_4215 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_4216 = torch.constant.float 8.192000e+03 + %3596 = torch.aten.gt.Scalar %3595, %float8.192000e03_4216 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> %int8_4217 = torch.constant.int 8 - %int4_4218 = torch.constant.int 4 - %int128_4219 = torch.constant.int 128 - %3585 = torch.prim.ListConstruct %int4_4216, %3527, %int8_4217, %int4_4218, %int128_4219 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_4220 = torch.constant.bool false - %3586 = torch.aten.expand %3584, %3585, %false_4220 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3586, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_4221 = torch.constant.int 0 - %3587 = torch.aten.clone %3586, %int0_4221 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3587, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_4222 = torch.constant.int 4 - %int32_4223 = torch.constant.int 32 - %int128_4224 = torch.constant.int 128 - %3588 = torch.prim.ListConstruct %int4_4222, %3527, %int32_4223, %int128_4224 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3589 = torch.aten._unsafe_view %3587, %3588 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3589, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_4225 = 
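The deleted index_put sequences above are the paged KV-cache writeback: the flat %arg3 buffer of [?, 2097152]-element pages is viewed as [pages, 32, 2, 32, 8, 128] (2097152 = 32*2*32*8*128), flattened so each (page, sub-block, k/v) slot becomes one row of [32, 8, 128], and aten.index_put scatters the new keys at slot ids derived from the page table; the value write reuses the same ids shifted by one (the aten.add.Scalar on %3544) to hit the adjacent slot. A rough sketch, with my reading of the axis order (slot ids below are hypothetical placeholders):

import torch

pages, blk, kv, toks, kvh, hd = 8, 32, 2, 32, 8, 128  # 32*2*32*8*128 == 2097152
cache = torch.zeros(pages, blk * kv * toks * kvh * hd, dtype=torch.float16)

rows  = cache.view(pages * blk * kv, toks, kvh, hd)  # the aten.view to [?,32,8,128]
new_k = torch.randn(8, toks, kvh, hd, dtype=torch.float16)   # [bs*blocks, ...]
new_v = torch.randn(8, toks, kvh, hd, dtype=torch.float16)
slots = torch.tensor([0, 2, 64, 66, 128, 130, 192, 194])     # page-table-derived ids

rows.index_put_((slots,), new_k)       # K write, accumulate=False in the IR
rows.index_put_((slots + 1,), new_v)   # V write: same ids shifted by one
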
torch.constant.int -2 - %3590 = torch.aten.unsqueeze %3486, %int-2_4225 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3590, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_4226 = torch.constant.int 1 - %3591 = torch.aten.size.int %3480, %int1_4226 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_4227 = torch.constant.int 4 - %int8_4228 = torch.constant.int 8 - %int4_4229 = torch.constant.int 4 - %int128_4230 = torch.constant.int 128 - %3592 = torch.prim.ListConstruct %int4_4227, %3591, %int8_4228, %int4_4229, %int128_4230 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_4231 = torch.constant.bool false - %3593 = torch.aten.expand %3590, %3592, %false_4231 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3593, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %3597 = torch.aten.div.Scalar %3593, %int8_4217 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %3598 = torch.aten.where.self %3596, %3597, %3593 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3599 = torch.aten.reciprocal %3595 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_4218 = torch.constant.int 8192 + %3600 = torch.aten.mul.Scalar %3599, %int8192_4218 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_4219 = torch.constant.int 1 + %int1_4220 = torch.constant.int 1 + %3601 = torch.aten.sub.Scalar %3600, %int1_4219, %int1_4220 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_4221 = torch.constant.int 3 + %3602 = torch.aten.div.Scalar %3601, %int3_4221 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_4222 = torch.constant.int 1 + %int1_4223 = torch.constant.int 1 + %3603 = torch.aten.rsub.Scalar %3602, %int1_4222, %int1_4223 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %3604 = torch.aten.mul.Tensor %3603, %3598 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_4224 = torch.constant.int 8 + %3605 = torch.aten.div.Scalar %3604, %int8_4224 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %3606 = torch.aten.mul.Tensor %3602, %3598 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_4225 = torch.constant.int 1 + %3607 = torch.aten.add.Tensor %3605, %3606, %int1_4225 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_4226 = torch.constant.float 2.048000e+03 + %3608 = torch.aten.lt.Scalar %3595, %float2.048000e03_4226 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %3609 = torch.aten.bitwise_not %3608 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_4227 = torch.constant.float 8.192000e+03 + %3610 = torch.aten.gt.Scalar %3595, %float8.192000e03_4227 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %3611 = torch.aten.bitwise_not %3610 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %3612 = torch.aten.mul.Tensor %3609, %3611 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %3613 = torch.aten.where.self %3612, %3607, %3598 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, 
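The deleted unsqueeze/expand/_unsafe_view chains above are how the trace materializes grouped-query attention ahead of SDPA: each of the 8 KV heads is broadcast 4 ways so K and V line up with the 32 query heads. This is equivalent to a repeat_interleave over the head axis, written as a zero-copy expand plus clone; a small sketch:

import torch

def expand_kv(x: torch.Tensor, n_rep: int = 4) -> torch.Tensor:
    # [bs, seq, 8, 128] -> [bs, seq, 8, 1, 128] -> expand -> [bs, seq, 32, 128]
    bs, seq, n_kv, hd = x.shape
    x = x.unsqueeze(-2).expand(bs, seq, n_kv, n_rep, hd)  # aten.unsqueeze + aten.expand
    return x.reshape(bs, seq, n_kv * n_rep, hd)           # aten.clone + _unsafe_view

k32 = expand_kv(torch.randn(4, 32, 8, 128))
assert k32.shape == (4, 32, 32, 128)
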
!torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3614 = torch.prim.ListConstruct %3613, %3613 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_4228 = torch.constant.int -1 + %3615 = torch.aten.cat %3614, %int-1_4228 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_4229 = torch.constant.int 6 + %3616 = torch.prims.convert_element_type %3615, %int6_4229 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_4230 = torch.constant.int 1 + %3617 = torch.aten.unsqueeze %3587, %int1_4230 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_4231 = torch.constant.int 6 + %3618 = torch.prims.convert_element_type %3617, %int6_4231 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> %int0_4232 = torch.constant.int 0 - %3594 = torch.aten.clone %3593, %int0_4232 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3594, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_4233 = torch.constant.int 4 - %int32_4234 = torch.constant.int 32 - %int128_4235 = torch.constant.int 128 - %3595 = torch.prim.ListConstruct %int4_4233, %3591, %int32_4234, %int128_4235 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3596 = torch.aten._unsafe_view %3594, %3595 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3596, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_4236 = torch.constant.int 1 - %int2_4237 = torch.constant.int 2 - %3597 = torch.aten.transpose.int %3514, %int1_4236, %int2_4237 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3597, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %3619 = torch.aten.unsqueeze %3616, %int0_4232 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_4233 = torch.constant.int 6 + %3620 = torch.prims.convert_element_type %3619, %int6_4233 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %3621 = torch.aten.mul.Tensor %3618, %3620 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %3622 = torch.aten.cos %3621 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_4234 = torch.constant.int 5 + %3623 = torch.prims.convert_element_type %3622, %int5_4234 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %3624 = torch.aten.sin %3621 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_4235 = torch.constant.int 5 + %3625 = torch.prims.convert_element_type %3624, %int5_4235 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_4236 = torch.constant.int 0 + %int0_4237 = torch.constant.int 0 %int1_4238 = torch.constant.int 1 - %int2_4239 = torch.constant.int 2 - %3598 = torch.aten.transpose.int %3589, %int1_4238, %int2_4239 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3598, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_4240 = torch.constant.int 1 - %int2_4241 = torch.constant.int 2 - %3599 = torch.aten.transpose.int %3596, %int1_4240, %int2_4241 : 
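The scalar chain added between %3588 and %3625 builds the rotary frequency table with the Llama 3 long-context scaling recipe: base inv_freq = 1 / 500000^(arange(0,128,2)/128); components whose wavelength 2*pi/f exceeds the original 8192-token context are divided by the scale factor 8, and wavelengths between 2048 and 8192 are blended via the smoothing term (8192/wavelen - 1)/3 (the sub/div/rsub run above); the 64 scaled values are duplicated to 128, multiplied by positions 0..131071, and cached as f16 cos/sin tables. The same recipe in PyTorch, as I read the constants:

import math
import torch

def llama3_inv_freq(head_dim=128, theta=5.0e5, factor=8.0,
                    low_ff=1.0, high_ff=4.0, old_ctx=8192):
    inv = 1.0 / theta ** (torch.arange(0, head_dim, 2).float() / head_dim)
    wavelen = 2 * math.pi / inv                       # aten.reciprocal + mul 2*pi
    scaled = torch.where(wavelen > old_ctx / low_ff, inv / factor, inv)
    smooth = (old_ctx / wavelen - low_ff) / (high_ff - low_ff)
    blended = (1 - smooth) * scaled / factor + smooth * scaled
    medium = (wavelen >= old_ctx / high_ff) & (wavelen <= old_ctx / low_ff)
    return torch.where(medium, blended, scaled)       # the two bitwise_not + mul

pos = torch.arange(131072).float().unsqueeze(1)                 # [131072, 1]
freqs = pos * torch.cat([llama3_inv_freq()] * 2).unsqueeze(0)   # [131072, 128]
cos, sin = freqs.cos().half(), freqs.sin().half()
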
!torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3599, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_4242 = torch.constant.float 0.000000e+00 - %true_4243 = torch.constant.bool true - %none_4244 = torch.constant.none - %none_4245 = torch.constant.none - %3600:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3597, %3598, %3599, %float0.000000e00_4242, %true_4243, %none_4244, %none_4245) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %3600#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %3626 = torch.aten.slice.Tensor %3623, %int0_4236, %int0_4237, %298, %int1_4238 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3626, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_4239 = torch.constant.int 1 + %int0_4240 = torch.constant.int 0 + %int9223372036854775807_4241 = torch.constant.int 9223372036854775807 + %int1_4242 = torch.constant.int 1 + %3627 = torch.aten.slice.Tensor %3626, %int1_4239, %int0_4240, %int9223372036854775807_4241, %int1_4242 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3627, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_4243 = torch.constant.int 0 + %int0_4244 = torch.constant.int 0 + %int1_4245 = torch.constant.int 1 + %3628 = torch.aten.slice.Tensor %3625, %int0_4243, %int0_4244, %298, %int1_4245 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3628, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_4246 = torch.constant.int 1 - %int2_4247 = torch.constant.int 2 - %3601 = torch.aten.transpose.int %3600#0, %int1_4246, %int2_4247 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3601, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_4248 = torch.constant.int 4 - %int4096_4249 = torch.constant.int 4096 - %3602 = torch.prim.ListConstruct %int4_4248, %3499, %int4096_4249 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3603 = torch.aten.view %3601, %3602 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3603, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_4250 = torch.constant.int -2 - %int-1_4251 = torch.constant.int -1 - %3604 = torch.aten.transpose.int %149, %int-2_4250, %int-1_4251 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_4252 = torch.constant.int 4 - %3605 = torch.aten.mul.int %int4_4252, %3499 : !torch.int, !torch.int -> !torch.int - %int4096_4253 = torch.constant.int 4096 - %3606 = torch.prim.ListConstruct %3605, %int4096_4253 : (!torch.int, !torch.int) -> !torch.list - %3607 = torch.aten.view %3603, %3606 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - 
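The deleted attention core above transposes Q, K, V to [4, 32, seq, 128] and invokes the CPU flash-attention kernel with dropout 0.0 and (reading the boolean operand) is_causal = true, returning the attention output plus an f32 logsumexp. A decomposed equivalent, sketched in f32 for portability:

import torch
import torch.nn.functional as F

bs, heads, seq, hd = 4, 32, 128, 128
q = torch.randn(bs, heads, seq, hd)   # [bs, heads, seq, dim] after transpose
k = torch.randn(bs, heads, seq, hd)
v = torch.randn(bs, heads, seq, hd)

# dropout_p=0.0, is_causal=True, default 1/sqrt(128) scaling, as in the IR
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
out = out.transpose(1, 2).reshape(bs, seq, heads * hd)   # back to [bs, seq, 4096]
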
torch.bind_symbolic_shape %3607, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3608 = torch.aten.mm %3607, %3604 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3608, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_4254 = torch.constant.int 4 - %int4096_4255 = torch.constant.int 4096 - %3609 = torch.prim.ListConstruct %int4_4254, %3499, %int4096_4255 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3610 = torch.aten.view %3608, %3609 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3610, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_4256 = torch.constant.int 1 - %3611 = torch.aten.add.Tensor %3449, %3610, %int1_4256 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3611, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_4257 = torch.constant.int 6 - %3612 = torch.prims.convert_element_type %3611, %int6_4257 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3612, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_4258 = torch.constant.int 2 - %3613 = torch.aten.pow.Tensor_Scalar %3612, %int2_4258 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3613, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_4259 = torch.constant.int -1 - %3614 = torch.prim.ListConstruct %int-1_4259 : (!torch.int) -> !torch.list - %true_4260 = torch.constant.bool true - %none_4261 = torch.constant.none - %3615 = torch.aten.mean.dim %3613, %3614, %true_4260, %none_4261 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3615, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_4262 = torch.constant.float 9.9999997473787516E-6 + %int0_4247 = torch.constant.int 0 + %int9223372036854775807_4248 = torch.constant.int 9223372036854775807 + %int1_4249 = torch.constant.int 1 + %3629 = torch.aten.slice.Tensor %3628, %int1_4246, %int0_4247, %int9223372036854775807_4248, %int1_4249 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3629, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_4250 = torch.constant.int 0 + %3630 = torch.aten.unsqueeze %3627, %int0_4250 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3630, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_4251 = torch.constant.int 1 + %int0_4252 = torch.constant.int 0 + %int9223372036854775807_4253 = torch.constant.int 9223372036854775807 + %int1_4254 = torch.constant.int 1 + %3631 = torch.aten.slice.Tensor %3630, %int1_4251, %int0_4252, %int9223372036854775807_4253, %int1_4254 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3631, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_4255 = torch.constant.int 2 + %3632 = 
torch.aten.unsqueeze %3631, %int2_4255 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3632, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_4256 = torch.constant.int 3 + %int0_4257 = torch.constant.int 0 + %int9223372036854775807_4258 = torch.constant.int 9223372036854775807 + %int1_4259 = torch.constant.int 1 + %3633 = torch.aten.slice.Tensor %3632, %int3_4256, %int0_4257, %int9223372036854775807_4258, %int1_4259 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3633, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_4260 = torch.constant.int 4 + %int1_4261 = torch.constant.int 1 + %int1_4262 = torch.constant.int 1 %int1_4263 = torch.constant.int 1 - %3616 = torch.aten.add.Scalar %3615, %float9.999990e-06_4262, %int1_4263 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3616, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3617 = torch.aten.rsqrt %3616 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3617, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3618 = torch.aten.mul.Tensor %3612, %3617 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3618, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_4264 = torch.constant.int 5 - %3619 = torch.prims.convert_element_type %3618, %int5_4264 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3619, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %3620 = torch.aten.mul.Tensor %150, %3619 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3620, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_4265 = torch.constant.int 5 - %3621 = torch.prims.convert_element_type %3620, %int5_4265 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3621, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_4266 = torch.constant.int -2 - %int-1_4267 = torch.constant.int -1 - %3622 = torch.aten.transpose.int %151, %int-2_4266, %int-1_4267 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_4268 = torch.constant.int 4 - %3623 = torch.aten.mul.int %int4_4268, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4269 = torch.constant.int 4096 - %3624 = torch.prim.ListConstruct %3623, %int4096_4269 : (!torch.int, !torch.int) -> !torch.list - %3625 = torch.aten.view %3621, %3624 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3625, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3626 = torch.aten.mm %3625, %3622 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %3626, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_4270 = torch.constant.int 4 - %int14336_4271 = torch.constant.int 14336 - %3627 = 
torch.prim.ListConstruct %int4_4270, %306, %int14336_4271 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3628 = torch.aten.view %3626, %3627 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3628, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %3629 = torch.aten.silu %3628 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3629, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_4272 = torch.constant.int -2 - %int-1_4273 = torch.constant.int -1 - %3630 = torch.aten.transpose.int %152, %int-2_4272, %int-1_4273 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %3634 = torch.prim.ListConstruct %int4_4260, %int1_4261, %int1_4262, %int1_4263 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3635 = torch.aten.repeat %3633, %3634 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3635, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_4264 = torch.constant.int 0 + %3636 = torch.aten.unsqueeze %3629, %int0_4264 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3636, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_4265 = torch.constant.int 1 + %int0_4266 = torch.constant.int 0 + %int9223372036854775807_4267 = torch.constant.int 9223372036854775807 + %int1_4268 = torch.constant.int 1 + %3637 = torch.aten.slice.Tensor %3636, %int1_4265, %int0_4266, %int9223372036854775807_4267, %int1_4268 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3637, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_4269 = torch.constant.int 2 + %3638 = torch.aten.unsqueeze %3637, %int2_4269 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3638, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_4270 = torch.constant.int 3 + %int0_4271 = torch.constant.int 0 + %int9223372036854775807_4272 = torch.constant.int 9223372036854775807 + %int1_4273 = torch.constant.int 1 + %3639 = torch.aten.slice.Tensor %3638, %int3_4270, %int0_4271, %int9223372036854775807_4272, %int1_4273 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3639, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> %int4_4274 = torch.constant.int 4 - %3631 = torch.aten.mul.int %int4_4274, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4275 = torch.constant.int 4096 - %3632 = torch.prim.ListConstruct %3631, %int4096_4275 : (!torch.int, !torch.int) -> !torch.list - %3633 = torch.aten.view %3621, %3632 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3633, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3634 = torch.aten.mm %3633, %3630 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %3634, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_4276 = 
torch.constant.int 4 - %int14336_4277 = torch.constant.int 14336 - %3635 = torch.prim.ListConstruct %int4_4276, %306, %int14336_4277 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3636 = torch.aten.view %3634, %3635 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3636, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %3637 = torch.aten.mul.Tensor %3629, %3636 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %3637, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_4278 = torch.constant.int -2 - %int-1_4279 = torch.constant.int -1 - %3638 = torch.aten.transpose.int %153, %int-2_4278, %int-1_4279 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_4280 = torch.constant.int 1 - %3639 = torch.aten.size.int %3628, %int1_4280 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_4281 = torch.constant.int 4 - %3640 = torch.aten.mul.int %int4_4281, %3639 : !torch.int, !torch.int -> !torch.int - %int14336_4282 = torch.constant.int 14336 - %3641 = torch.prim.ListConstruct %3640, %int14336_4282 : (!torch.int, !torch.int) -> !torch.list - %3642 = torch.aten.view %3637, %3641 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %3642, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %3643 = torch.aten.mm %3642, %3638 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3643, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_4283 = torch.constant.int 4 - %int4096_4284 = torch.constant.int 4096 - %3644 = torch.prim.ListConstruct %int4_4283, %3639, %int4096_4284 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3645 = torch.aten.view %3643, %3644 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3645, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_4275 = torch.constant.int 1 + %int1_4276 = torch.constant.int 1 + %int1_4277 = torch.constant.int 1 + %3640 = torch.prim.ListConstruct %int4_4274, %int1_4275, %int1_4276, %int1_4277 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3641 = torch.aten.repeat %3639, %3640 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3641, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %3642 = torch.aten.mul.Tensor %3582, %3635 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3642, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_4278 = torch.constant.int 3 + %int0_4279 = torch.constant.int 0 + %int64_4280 = torch.constant.int 64 + %int1_4281 = torch.constant.int 1 + %3643 = torch.aten.slice.Tensor %3582, %int3_4278, %int0_4279, %int64_4280, %int1_4281 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %3643, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_4282 = torch.constant.int 3 + %int64_4283 = 
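The deleted feed-forward block above is the standard SwiGLU MLP: gate and up projections to 14336 channels, silu on the gate, an elementwise product, a down projection back to 4096, then the residual add at %3646. Compactly (f32 here; the IR runs in f16):

import torch
import torch.nn.functional as F

d_model, d_ff = 4096, 14336
x  = torch.randn(4, 32, d_model)
wg = torch.randn(d_ff, d_model)   # ffn_gate, [14336, 4096]
wu = torch.randn(d_ff, d_model)   # ffn_up,   [14336, 4096]
wd = torch.randn(d_model, d_ff)   # ffn_down, [4096, 14336]

h = F.silu(x @ wg.T) * (x @ wu.T)   # aten.silu + aten.mul.Tensor
y = x + h @ wd.T                    # down projection + residual add
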
torch.constant.int 64 + %int9223372036854775807_4284 = torch.constant.int 9223372036854775807 %int1_4285 = torch.constant.int 1 - %3646 = torch.aten.add.Tensor %3611, %3645, %int1_4285 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3646, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_4286 = torch.constant.int 6 - %3647 = torch.prims.convert_element_type %3646, %int6_4286 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3647, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_4287 = torch.constant.int 2 - %3648 = torch.aten.pow.Tensor_Scalar %3647, %int2_4287 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3648, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_4288 = torch.constant.int -1 - %3649 = torch.prim.ListConstruct %int-1_4288 : (!torch.int) -> !torch.list - %true_4289 = torch.constant.bool true + %3644 = torch.aten.slice.Tensor %3582, %int3_4282, %int64_4283, %int9223372036854775807_4284, %int1_4285 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %3644, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %3645 = torch.aten.neg %3644 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %3645, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %3646 = torch.prim.ListConstruct %3645, %3643 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_4286 = torch.constant.int -1 + %3647 = torch.aten.cat %3646, %int-1_4286 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3647, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %3648 = torch.aten.mul.Tensor %3647, %3641 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3648, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_4287 = torch.constant.int 1 + %3649 = torch.aten.add.Tensor %3642, %3648, %int1_4287 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3649, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_4288 = torch.constant.int 131072 + %none_4289 = torch.constant.none %none_4290 = torch.constant.none - %3650 = torch.aten.mean.dim %3648, %3649, %true_4289, %none_4290 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3650, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_4291 = torch.constant.float 9.9999997473787516E-6 - %int1_4292 = torch.constant.int 1 - %3651 = torch.aten.add.Scalar %3650, %float9.999990e-06_4291, %int1_4292 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3651, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3652 = torch.aten.rsqrt %3651 : !torch.vtensor<[4,?,1],f32> 
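Ops %3642 through %3649 apply the rotary embedding to the query heads in the non-interleaved "rotate half" form: the last dim is split at 64, the upper half is negated and swapped in front (the aten.neg + aten.cat pair), and the result is combined as q*cos + rotate_half(q)*sin, with the [4, seq, 1, 128] cos/sin tensors broadcasting over the 32 heads. As a sketch:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def apply_rope(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # cos/sin: [bs, seq, 1, 128], broadcasting across the head axis
    return q * cos + rotate_half(q) * sin
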
-> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3652, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3653 = torch.aten.mul.Tensor %3647, %3652 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3653, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_4293 = torch.constant.int 5 - %3654 = torch.prims.convert_element_type %3653, %int5_4293 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3654, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %3655 = torch.aten.mul.Tensor %154, %3654 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3655, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_4294 = torch.constant.int 5 - %3656 = torch.prims.convert_element_type %3655, %int5_4294 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3656, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_4295 = torch.constant.int -2 - %int-1_4296 = torch.constant.int -1 - %3657 = torch.aten.transpose.int %155, %int-2_4295, %int-1_4296 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_4297 = torch.constant.int 4 - %3658 = torch.aten.mul.int %int4_4297, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4298 = torch.constant.int 4096 - %3659 = torch.prim.ListConstruct %3658, %int4096_4298 : (!torch.int, !torch.int) -> !torch.list - %3660 = torch.aten.view %3656, %3659 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3660, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3661 = torch.aten.mm %3660, %3657 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3661, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_4299 = torch.constant.int 4 - %int4096_4300 = torch.constant.int 4096 - %3662 = torch.prim.ListConstruct %int4_4299, %306, %int4096_4300 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3663 = torch.aten.view %3661, %3662 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3663, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_4301 = torch.constant.int -2 - %int-1_4302 = torch.constant.int -1 - %3664 = torch.aten.transpose.int %156, %int-2_4301, %int-1_4302 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_4303 = torch.constant.int 4 - %3665 = torch.aten.mul.int %int4_4303, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4304 = torch.constant.int 4096 - %3666 = torch.prim.ListConstruct %3665, %int4096_4304 : (!torch.int, !torch.int) -> !torch.list - %3667 = torch.aten.view %3656, %3666 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3667, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3668 = torch.aten.mm %3667, %3664 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %3668, [%292], 
affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_4305 = torch.constant.int 4 - %int1024_4306 = torch.constant.int 1024 - %3669 = torch.prim.ListConstruct %int4_4305, %306, %int1024_4306 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3670 = torch.aten.view %3668, %3669 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %3670, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_4307 = torch.constant.int -2 - %int-1_4308 = torch.constant.int -1 - %3671 = torch.aten.transpose.int %157, %int-2_4307, %int-1_4308 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_4309 = torch.constant.int 4 - %3672 = torch.aten.mul.int %int4_4309, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4310 = torch.constant.int 4096 - %3673 = torch.prim.ListConstruct %3672, %int4096_4310 : (!torch.int, !torch.int) -> !torch.list - %3674 = torch.aten.view %3656, %3673 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3674, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3675 = torch.aten.mm %3674, %3671 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %3675, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_4311 = torch.constant.int 4 - %int1024_4312 = torch.constant.int 1024 - %3676 = torch.prim.ListConstruct %int4_4311, %306, %int1024_4312 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3677 = torch.aten.view %3675, %3676 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %3677, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_4313 = torch.constant.int 4 - %int32_4314 = torch.constant.int 32 - %int128_4315 = torch.constant.int 128 - %3678 = torch.prim.ListConstruct %int4_4313, %306, %int32_4314, %int128_4315 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3679 = torch.aten.view %3663, %3678 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3679, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_4316 = torch.constant.int 4 - %int8_4317 = torch.constant.int 8 - %int128_4318 = torch.constant.int 128 - %3680 = torch.prim.ListConstruct %int4_4316, %306, %int8_4317, %int128_4318 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3681 = torch.aten.view %3670, %3680 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3681, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_4319 = torch.constant.int 4 - %int8_4320 = torch.constant.int 8 - %int128_4321 = torch.constant.int 128 - %3682 = torch.prim.ListConstruct %int4_4319, %306, %int8_4320, %int128_4321 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3683 = torch.aten.view %3677, %3682 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3683, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_4322 = torch.constant.int 131072 - %none_4323 = torch.constant.none - %none_4324 = torch.constant.none - %cpu_4325 = torch.constant.device 
"cpu" - %false_4326 = torch.constant.bool false - %3684 = torch.aten.arange %int131072_4322, %none_4323, %none_4324, %cpu_4325, %false_4326 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_4327 = torch.constant.int 0 - %int128_4328 = torch.constant.int 128 - %none_4329 = torch.constant.none - %none_4330 = torch.constant.none - %cpu_4331 = torch.constant.device "cpu" - %false_4332 = torch.constant.bool false - %3685 = torch.aten.arange.start %int0_4327, %int128_4328, %none_4329, %none_4330, %cpu_4331, %false_4332 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_4333 = torch.constant.int 2 - %3686 = torch.aten.floor_divide.Scalar %3685, %int2_4333 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_4334 = torch.constant.int 6 - %3687 = torch.prims.convert_element_type %3686, %int6_4334 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_4335 = torch.constant.int 128 - %3688 = torch.aten.div.Scalar %3687, %int128_4335 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_4336 = torch.constant.float 2.000000e+00 - %3689 = torch.aten.mul.Scalar %3688, %float2.000000e00_4336 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_4337 = torch.constant.float 5.000000e+05 - %3690 = torch.aten.pow.Scalar %float5.000000e05_4337, %3689 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %3691 = torch.aten.reciprocal %3690 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_4338 = torch.constant.float 1.000000e+00 - %3692 = torch.aten.mul.Scalar %3691, %float1.000000e00_4338 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_4339 = torch.constant.int 1 - %3693 = torch.aten.unsqueeze %3684, %int1_4339 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_4340 = torch.constant.int 0 - %3694 = torch.aten.unsqueeze %3692, %int0_4340 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %3695 = torch.aten.mul.Tensor %3693, %3694 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_4341 = torch.constant.int 1 - %3696 = torch.aten.size.int %3663, %int1_4341 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_4342 = torch.constant.int 0 - %3697 = torch.aten.add.int %int0_4342, %3696 : !torch.int, !torch.int -> !torch.int - %int0_4343 = torch.constant.int 0 - %int0_4344 = torch.constant.int 0 - %int1_4345 = torch.constant.int 1 - %3698 = torch.aten.slice.Tensor %3695, %int0_4343, %int0_4344, %3697, %int1_4345 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3698, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_4346 = torch.constant.int 1 - %int0_4347 = torch.constant.int 0 - %int9223372036854775807_4348 = torch.constant.int 9223372036854775807 - %int1_4349 = torch.constant.int 1 - %3699 = torch.aten.slice.Tensor %3698, %int1_4346, %int0_4347, %int9223372036854775807_4348, %int1_4349 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3699, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %cpu_4291 = torch.constant.device "cpu" + 
%false_4292 = torch.constant.bool false + %3650 = torch.aten.arange %int131072_4288, %none_4289, %none_4290, %cpu_4291, %false_4292 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_4293 = torch.constant.int 0 + %int128_4294 = torch.constant.int 128 + %int2_4295 = torch.constant.int 2 + %int4_4296 = torch.constant.int 4 + %none_4297 = torch.constant.none + %cpu_4298 = torch.constant.device "cpu" + %false_4299 = torch.constant.bool false + %3651 = torch.aten.arange.start_step %int0_4293, %int128_4294, %int2_4295, %int4_4296, %none_4297, %cpu_4298, %false_4299 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_4300 = torch.constant.int 6 + %3652 = torch.prims.convert_element_type %3651, %int6_4300 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_4301 = torch.constant.int 128 + %3653 = torch.aten.div.Scalar %3652, %int128_4301 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_4302 = torch.constant.float 5.000000e+05 + %3654 = torch.aten.pow.Scalar %float5.000000e05_4302, %3653 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3655 = torch.aten.reciprocal %3654 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_4303 = torch.constant.float 1.000000e+00 + %3656 = torch.aten.mul.Scalar %3655, %float1.000000e00_4303 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %3657 = torch.aten.reciprocal %3656 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_4304 = torch.constant.float 6.2831853071795862 + %3658 = torch.aten.mul.Scalar %3657, %float6.283190e00_4304 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_4305 = torch.constant.float 8.192000e+03 + %3659 = torch.aten.gt.Scalar %3658, %float8.192000e03_4305 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_4306 = torch.constant.int 8 + %3660 = torch.aten.div.Scalar %3656, %int8_4306 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %3661 = torch.aten.where.self %3659, %3660, %3656 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3662 = torch.aten.reciprocal %3658 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_4307 = torch.constant.int 8192 + %3663 = torch.aten.mul.Scalar %3662, %int8192_4307 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_4308 = torch.constant.int 1 + %int1_4309 = torch.constant.int 1 + %3664 = torch.aten.sub.Scalar %3663, %int1_4308, %int1_4309 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_4310 = torch.constant.int 3 + %3665 = torch.aten.div.Scalar %3664, %int3_4310 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_4311 = torch.constant.int 1 + %int1_4312 = torch.constant.int 1 + %3666 = torch.aten.rsub.Scalar %3665, %int1_4311, %int1_4312 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %3667 = torch.aten.mul.Tensor %3666, %3661 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_4313 = torch.constant.int 8 + %3668 = torch.aten.div.Scalar %3667, %int8_4313 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %3669 = torch.aten.mul.Tensor %3665, %3661 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> 
!torch.vtensor<[64],f32> + %int1_4314 = torch.constant.int 1 + %3670 = torch.aten.add.Tensor %3668, %3669, %int1_4314 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_4315 = torch.constant.float 2.048000e+03 + %3671 = torch.aten.lt.Scalar %3658, %float2.048000e03_4315 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %3672 = torch.aten.bitwise_not %3671 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_4316 = torch.constant.float 8.192000e+03 + %3673 = torch.aten.gt.Scalar %3658, %float8.192000e03_4316 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %3674 = torch.aten.bitwise_not %3673 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %3675 = torch.aten.mul.Tensor %3672, %3674 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %3676 = torch.aten.where.self %3675, %3670, %3661 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %3677 = torch.prim.ListConstruct %3676, %3676 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_4317 = torch.constant.int -1 + %3678 = torch.aten.cat %3677, %int-1_4317 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_4318 = torch.constant.int 6 + %3679 = torch.prims.convert_element_type %3678, %int6_4318 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_4319 = torch.constant.int 1 + %3680 = torch.aten.unsqueeze %3650, %int1_4319 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_4320 = torch.constant.int 6 + %3681 = torch.prims.convert_element_type %3680, %int6_4320 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_4321 = torch.constant.int 0 + %3682 = torch.aten.unsqueeze %3679, %int0_4321 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_4322 = torch.constant.int 6 + %3683 = torch.prims.convert_element_type %3682, %int6_4322 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %3684 = torch.aten.mul.Tensor %3681, %3683 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %3685 = torch.aten.cos %3684 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_4323 = torch.constant.int 5 + %3686 = torch.prims.convert_element_type %3685, %int5_4323 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %3687 = torch.aten.sin %3684 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_4324 = torch.constant.int 5 + %3688 = torch.prims.convert_element_type %3687, %int5_4324 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_4325 = torch.constant.int 0 + %int0_4326 = torch.constant.int 0 + %int1_4327 = torch.constant.int 1 + %3689 = torch.aten.slice.Tensor %3686, %int0_4325, %int0_4326, %298, %int1_4327 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3689, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_4328 = torch.constant.int 1 + %int0_4329 = torch.constant.int 0 + %int9223372036854775807_4330 = torch.constant.int 9223372036854775807 + %int1_4331 = torch.constant.int 1 + %3690 = torch.aten.slice.Tensor %3689, %int1_4328, %int0_4329, %int9223372036854775807_4330, %int1_4331 : 
!torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3690, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_4332 = torch.constant.int 0 + %int0_4333 = torch.constant.int 0 + %int1_4334 = torch.constant.int 1 + %3691 = torch.aten.slice.Tensor %3688, %int0_4332, %int0_4333, %298, %int1_4334 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3691, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_4335 = torch.constant.int 1 + %int0_4336 = torch.constant.int 0 + %int9223372036854775807_4337 = torch.constant.int 9223372036854775807 + %int1_4338 = torch.constant.int 1 + %3692 = torch.aten.slice.Tensor %3691, %int1_4335, %int0_4336, %int9223372036854775807_4337, %int1_4338 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3692, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_4339 = torch.constant.int 0 + %3693 = torch.aten.unsqueeze %3690, %int0_4339 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3693, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_4340 = torch.constant.int 1 + %int0_4341 = torch.constant.int 0 + %int9223372036854775807_4342 = torch.constant.int 9223372036854775807 + %int1_4343 = torch.constant.int 1 + %3694 = torch.aten.slice.Tensor %3693, %int1_4340, %int0_4341, %int9223372036854775807_4342, %int1_4343 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3694, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_4344 = torch.constant.int 2 + %3695 = torch.aten.unsqueeze %3694, %int2_4344 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3695, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_4345 = torch.constant.int 3 + %int0_4346 = torch.constant.int 0 + %int9223372036854775807_4347 = torch.constant.int 9223372036854775807 + %int1_4348 = torch.constant.int 1 + %3696 = torch.aten.slice.Tensor %3695, %int3_4345, %int0_4346, %int9223372036854775807_4347, %int1_4348 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3696, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_4349 = torch.constant.int 4 %int1_4350 = torch.constant.int 1 - %int0_4351 = torch.constant.int 0 - %int9223372036854775807_4352 = torch.constant.int 9223372036854775807 - %int1_4353 = torch.constant.int 1 - %3700 = torch.aten.slice.Tensor %3699, %int1_4350, %int0_4351, %int9223372036854775807_4352, %int1_4353 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3700, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_4354 = torch.constant.int 0 - %3701 = torch.aten.unsqueeze %3700, %int0_4354 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3701, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - 
%int1_4355 = torch.constant.int 1 - %int0_4356 = torch.constant.int 0 - %int9223372036854775807_4357 = torch.constant.int 9223372036854775807 - %int1_4358 = torch.constant.int 1 - %3702 = torch.aten.slice.Tensor %3701, %int1_4355, %int0_4356, %int9223372036854775807_4357, %int1_4358 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3702, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_4359 = torch.constant.int 2 + %int1_4351 = torch.constant.int 1 + %int1_4352 = torch.constant.int 1 + %3697 = torch.prim.ListConstruct %int4_4349, %int1_4350, %int1_4351, %int1_4352 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3698 = torch.aten.repeat %3696, %3697 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3698, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_4353 = torch.constant.int 0 + %3699 = torch.aten.unsqueeze %3692, %int0_4353 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3699, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_4354 = torch.constant.int 1 + %int0_4355 = torch.constant.int 0 + %int9223372036854775807_4356 = torch.constant.int 9223372036854775807 + %int1_4357 = torch.constant.int 1 + %3700 = torch.aten.slice.Tensor %3699, %int1_4354, %int0_4355, %int9223372036854775807_4356, %int1_4357 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3700, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_4358 = torch.constant.int 2 + %3701 = torch.aten.unsqueeze %3700, %int2_4358 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3701, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_4359 = torch.constant.int 3 %int0_4360 = torch.constant.int 0 %int9223372036854775807_4361 = torch.constant.int 9223372036854775807 %int1_4362 = torch.constant.int 1 - %3703 = torch.aten.slice.Tensor %3702, %int2_4359, %int0_4360, %int9223372036854775807_4361, %int1_4362 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3703, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> + %3702 = torch.aten.slice.Tensor %3701, %int3_4359, %int0_4360, %int9223372036854775807_4361, %int1_4362 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3702, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> %int4_4363 = torch.constant.int 4 %int1_4364 = torch.constant.int 1 %int1_4365 = torch.constant.int 1 - %3704 = torch.prim.ListConstruct %int4_4363, %int1_4364, %int1_4365 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3705 = torch.aten.repeat %3703, %3704 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %3705, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_4366 = torch.constant.int 6 - %3706 = torch.prims.convert_element_type %3679, %int6_4366 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> 
!torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %3706, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %3707 = torch_c.to_builtin_tensor %3706 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %3708 = torch_c.to_builtin_tensor %3705 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %3709 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%3707, %3708) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %3710 = torch_c.from_builtin_tensor %3709 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %3710, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_4367 = torch.constant.int 5 - %3711 = torch.prims.convert_element_type %3710, %int5_4367 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3711, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_4368 = torch.constant.int 131072 - %none_4369 = torch.constant.none - %none_4370 = torch.constant.none - %cpu_4371 = torch.constant.device "cpu" - %false_4372 = torch.constant.bool false - %3712 = torch.aten.arange %int131072_4368, %none_4369, %none_4370, %cpu_4371, %false_4372 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_4373 = torch.constant.int 0 - %int128_4374 = torch.constant.int 128 - %none_4375 = torch.constant.none - %none_4376 = torch.constant.none - %cpu_4377 = torch.constant.device "cpu" - %false_4378 = torch.constant.bool false - %3713 = torch.aten.arange.start %int0_4373, %int128_4374, %none_4375, %none_4376, %cpu_4377, %false_4378 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> + %int1_4366 = torch.constant.int 1 + %3703 = torch.prim.ListConstruct %int4_4363, %int1_4364, %int1_4365, %int1_4366 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3704 = torch.aten.repeat %3702, %3703 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3704, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %3705 = torch.aten.mul.Tensor %3584, %3698 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3705, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_4367 = torch.constant.int 3 + %int0_4368 = torch.constant.int 0 + %int64_4369 = torch.constant.int 64 + %int1_4370 = torch.constant.int 1 + %3706 = torch.aten.slice.Tensor %3584, %int3_4367, %int0_4368, %int64_4369, %int1_4370 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %3706, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_4371 = torch.constant.int 3 + %int64_4372 = torch.constant.int 64 + %int9223372036854775807_4373 = torch.constant.int 9223372036854775807 + %int1_4374 = torch.constant.int 1 + %3707 = torch.aten.slice.Tensor %3584, %int3_4371, %int64_4372, %int9223372036854775807_4373, %int1_4374 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %3707, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : 
!torch.vtensor<[4,?,8,64],f16> + %3708 = torch.aten.neg %3707 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %3708, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %3709 = torch.prim.ListConstruct %3708, %3706 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_4375 = torch.constant.int -1 + %3710 = torch.aten.cat %3709, %int-1_4375 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3710, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %3711 = torch.aten.mul.Tensor %3710, %3704 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3711, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_4376 = torch.constant.int 1 + %3712 = torch.aten.add.Tensor %3705, %3711, %int1_4376 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3712, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_4377 = torch.constant.int 32 + %3713 = torch.aten.mul.Scalar %arg2, %int32_4377 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3713, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int12 = torch.constant.int 12 + %int1_4378 = torch.constant.int 1 + %3714 = torch.aten.add.Scalar %3713, %int12, %int1_4378 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3714, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> %int2_4379 = torch.constant.int 2 - %3714 = torch.aten.floor_divide.Scalar %3713, %int2_4379 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_4380 = torch.constant.int 6 - %3715 = torch.prims.convert_element_type %3714, %int6_4380 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_4381 = torch.constant.int 128 - %3716 = torch.aten.div.Scalar %3715, %int128_4381 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_4382 = torch.constant.float 2.000000e+00 - %3717 = torch.aten.mul.Scalar %3716, %float2.000000e00_4382 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_4383 = torch.constant.float 5.000000e+05 - %3718 = torch.aten.pow.Scalar %float5.000000e05_4383, %3717 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %3719 = torch.aten.reciprocal %3718 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_4384 = torch.constant.float 1.000000e+00 - %3720 = torch.aten.mul.Scalar %3719, %float1.000000e00_4384 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_4385 = torch.constant.int 1 - %3721 = torch.aten.unsqueeze %3712, %int1_4385 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_4386 = torch.constant.int 0 - %3722 = torch.aten.unsqueeze %3720, %int0_4386 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %3723 = torch.aten.mul.Tensor %3721, %3722 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_4387 = torch.constant.int 1 - %3724 = torch.aten.size.int %3670, %int1_4387 : 
!torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_4388 = torch.constant.int 0 - %3725 = torch.aten.add.int %int0_4388, %3724 : !torch.int, !torch.int -> !torch.int - %int0_4389 = torch.constant.int 0 - %int0_4390 = torch.constant.int 0 - %int1_4391 = torch.constant.int 1 - %3726 = torch.aten.slice.Tensor %3723, %int0_4389, %int0_4390, %3725, %int1_4391 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3726, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_4392 = torch.constant.int 1 - %int0_4393 = torch.constant.int 0 - %int9223372036854775807_4394 = torch.constant.int 9223372036854775807 - %int1_4395 = torch.constant.int 1 - %3727 = torch.aten.slice.Tensor %3726, %int1_4392, %int0_4393, %int9223372036854775807_4394, %int1_4395 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3727, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_4396 = torch.constant.int 1 - %int0_4397 = torch.constant.int 0 - %int9223372036854775807_4398 = torch.constant.int 9223372036854775807 - %int1_4399 = torch.constant.int 1 - %3728 = torch.aten.slice.Tensor %3727, %int1_4396, %int0_4397, %int9223372036854775807_4398, %int1_4399 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %3728, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_4400 = torch.constant.int 0 - %3729 = torch.aten.unsqueeze %3728, %int0_4400 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3729, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_4401 = torch.constant.int 1 - %int0_4402 = torch.constant.int 0 - %int9223372036854775807_4403 = torch.constant.int 9223372036854775807 - %int1_4404 = torch.constant.int 1 - %3730 = torch.aten.slice.Tensor %3729, %int1_4401, %int0_4402, %int9223372036854775807_4403, %int1_4404 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3730, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_4405 = torch.constant.int 2 - %int0_4406 = torch.constant.int 0 - %int9223372036854775807_4407 = torch.constant.int 9223372036854775807 - %int1_4408 = torch.constant.int 1 - %3731 = torch.aten.slice.Tensor %3730, %int2_4405, %int0_4406, %int9223372036854775807_4407, %int1_4408 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %3731, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_4409 = torch.constant.int 4 - %int1_4410 = torch.constant.int 1 - %int1_4411 = torch.constant.int 1 - %3732 = torch.prim.ListConstruct %int4_4409, %int1_4410, %int1_4411 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3733 = torch.aten.repeat %3731, %3732 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %3733, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_4412 = torch.constant.int 6 - %3734 = torch.prims.convert_element_type %3681, %int6_4412 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - 
torch.bind_symbolic_shape %3734, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %3735 = torch_c.to_builtin_tensor %3734 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %3736 = torch_c.to_builtin_tensor %3733 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %3737 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%3735, %3736) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %3738 = torch_c.from_builtin_tensor %3737 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %3738, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_4413 = torch.constant.int 5 - %3739 = torch.prims.convert_element_type %3738, %int5_4413 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3739, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_4414 = torch.constant.int 64 - %3740 = torch.aten.mul.Scalar %arg2, %int64_4414 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3740, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int34 = torch.constant.int 34 - %int1_4415 = torch.constant.int 1 - %3741 = torch.aten.add.Scalar %3740, %int34, %int1_4415 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3741, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_4416 = torch.constant.int 4 - %int32_4417 = torch.constant.int 32 - %int8_4418 = torch.constant.int 8 - %int128_4419 = torch.constant.int 128 - %3742 = torch.prim.ListConstruct %int4_4416, %398, %int32_4417, %int8_4418, %int128_4419 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3743 = torch.aten.view %3739, %3742 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3743, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_4420 = torch.constant.int 4 - %3744 = torch.aten.mul.int %int4_4420, %398 : !torch.int, !torch.int -> !torch.int - %int32_4421 = torch.constant.int 32 - %int8_4422 = torch.constant.int 8 - %int128_4423 = torch.constant.int 128 - %3745 = torch.prim.ListConstruct %3744, %int32_4421, %int8_4422, %int128_4423 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3746 = torch.aten.view %3743, %3745 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3746, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_4424 = torch.constant.int 4 - %3747 = torch.aten.mul.int %int4_4424, %398 : !torch.int, !torch.int -> !torch.int - %3748 = torch.prim.ListConstruct %3747 : (!torch.int) -> !torch.list - %3749 = torch.aten.view %3741, %3748 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3749, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %3715 = torch.aten.mul.Scalar %3714, %int2_4379 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3715, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_4380 = torch.constant.int 0 + %int1_4381 = torch.constant.int 1 + %3716 = torch.aten.add.Scalar %3715, %int0_4380, %int1_4381 : !torch.vtensor<[4,?],si64>, !torch.int, 
!torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3716, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %3717 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %3718 = torch.aten.view %3716, %3717 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %3718, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_4382 = torch.constant.int 4 + %int32_4383 = torch.constant.int 32 + %int8_4384 = torch.constant.int 8 + %int128_4385 = torch.constant.int 128 + %3719 = torch.prim.ListConstruct %int4_4382, %296, %int32_4383, %int8_4384, %int128_4385 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3720 = torch.aten.view %3712, %3719 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3720, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_4386 = torch.constant.int 32 + %int8_4387 = torch.constant.int 8 + %int128_4388 = torch.constant.int 128 + %3721 = torch.prim.ListConstruct %504, %int32_4386, %int8_4387, %int128_4388 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3722 = torch.aten.view %3720, %3721 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %3722, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_4389 = torch.constant.int 1 + %int2_4390 = torch.constant.int 2 + %3723 = torch.aten.transpose.int %3722, %int1_4389, %int2_4390 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3723, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_4391 = torch.constant.int 5 + %3724 = torch.prims.convert_element_type %3723, %int5_4391 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3724, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_4392 = torch.constant.int 32 + %int2_4393 = torch.constant.int 2 + %int8_4394 = torch.constant.int 8 + %int32_4395 = torch.constant.int 32 + %int128_4396 = torch.constant.int 128 + %3725 = torch.prim.ListConstruct %297, %int32_4392, %int2_4393, %int8_4394, %int32_4395, %int128_4396 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3726 = torch.aten.view %3488, %3725 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3726, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_4397 = torch.constant.int 8 + %int32_4398 = torch.constant.int 32 + %int128_4399 = torch.constant.int 128 + %3727 = torch.prim.ListConstruct %497, %int8_4397, %int32_4398, %int128_4399 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3728 = torch.aten.view %3726, %3727 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3728, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %3729 = torch.prim.ListConstruct %3718 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_4400 = torch.constant.bool false + %3730 = torch.aten.index_put %3728, %3729, %3724, %false_4400 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, 
!torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3730, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_4401 = torch.constant.int 32 + %int2_4402 = torch.constant.int 2 + %int8_4403 = torch.constant.int 8 + %int32_4404 = torch.constant.int 32 + %int128_4405 = torch.constant.int 128 + %3731 = torch.prim.ListConstruct %297, %int32_4401, %int2_4402, %int8_4403, %int32_4404, %int128_4405 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3732 = torch.aten.view %3730, %3731 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3732, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_4406 = torch.constant.int 2097152 + %3733 = torch.prim.ListConstruct %297, %int2097152_4406 : (!torch.int, !torch.int) -> !torch.list + %3734 = torch.aten.view %3732, %3733 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %3734, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_4407 = torch.constant.int 32 + %int2_4408 = torch.constant.int 2 + %int8_4409 = torch.constant.int 8 + %int32_4410 = torch.constant.int 32 + %int128_4411 = torch.constant.int 128 + %3735 = torch.prim.ListConstruct %297, %int32_4407, %int2_4408, %int8_4409, %int32_4410, %int128_4411 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3736 = torch.aten.view %3734, %3735 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3736, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_4412 = torch.constant.int 8 + %int32_4413 = torch.constant.int 32 + %int128_4414 = torch.constant.int 128 + %3737 = torch.prim.ListConstruct %497, %int8_4412, %int32_4413, %int128_4414 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3738 = torch.aten.view %3736, %3737 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3738, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_4415 = torch.constant.int 32 + %3739 = torch.aten.mul.Scalar %arg2, %int32_4415 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3739, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int12_4416 = torch.constant.int 12 + %int1_4417 = torch.constant.int 1 + %3740 = torch.aten.add.Scalar %3739, %int12_4416, %int1_4417 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3740, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_4418 = torch.constant.int 2 + %3741 = torch.aten.mul.Scalar %3740, %int2_4418 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3741, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_4419 = torch.constant.int 1 + %int1_4420 = torch.constant.int 1 + %3742 = torch.aten.add.Scalar %3741, %int1_4419, %int1_4420 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3742, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %3743 = 
torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %3744 = torch.aten.view %3742, %3743 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %3744, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_4421 = torch.constant.int 4 + %int32_4422 = torch.constant.int 32 + %int8_4423 = torch.constant.int 8 + %int128_4424 = torch.constant.int 128 + %3745 = torch.prim.ListConstruct %int4_4421, %296, %int32_4422, %int8_4423, %int128_4424 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3746 = torch.aten.view %3586, %3745 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3746, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int32_4425 = torch.constant.int 32 - %int2_4426 = torch.constant.int 2 - %int32_4427 = torch.constant.int 32 - %int8_4428 = torch.constant.int 8 - %int128_4429 = torch.constant.int 128 - %3750 = torch.prim.ListConstruct %389, %int32_4425, %int2_4426, %int32_4427, %int8_4428, %int128_4429 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3751 = torch.aten.view %3583, %3750 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3751, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4430 = torch.constant.int 32 - %3752 = torch.aten.mul.int %389, %int32_4430 : !torch.int, !torch.int -> !torch.int - %int2_4431 = torch.constant.int 2 - %3753 = torch.aten.mul.int %3752, %int2_4431 : !torch.int, !torch.int -> !torch.int + %int8_4426 = torch.constant.int 8 + %int128_4427 = torch.constant.int 128 + %3747 = torch.prim.ListConstruct %504, %int32_4425, %int8_4426, %int128_4427 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3748 = torch.aten.view %3746, %3747 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %3748, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_4428 = torch.constant.int 1 + %int2_4429 = torch.constant.int 2 + %3749 = torch.aten.transpose.int %3748, %int1_4428, %int2_4429 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3749, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_4430 = torch.constant.int 5 + %3750 = torch.prims.convert_element_type %3749, %int5_4430 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3750, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %3751 = torch.prim.ListConstruct %3744 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_4431 = torch.constant.bool false + %3752 = torch.aten.index_put %3738, %3751, %3750, %false_4431 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3752, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> %int32_4432 = torch.constant.int 32 - %int8_4433 = torch.constant.int 8 - %int128_4434 = torch.constant.int 128 - %3754 = torch.prim.ListConstruct %3753, %int32_4432, %int8_4433, %int128_4434 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3755 = 
torch.aten.view %3751, %3754 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3755, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %3756 = torch.prim.ListConstruct %3749 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_4435 = torch.constant.bool false - %3757 = torch.aten.index_put %3755, %3756, %3746, %false_4435 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3757, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_4436 = torch.constant.int 32 - %int2_4437 = torch.constant.int 2 - %int32_4438 = torch.constant.int 32 - %int8_4439 = torch.constant.int 8 - %int128_4440 = torch.constant.int 128 - %3758 = torch.prim.ListConstruct %389, %int32_4436, %int2_4437, %int32_4438, %int8_4439, %int128_4440 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3759 = torch.aten.view %3757, %3758 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3759, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_4441 = torch.constant.int 2097152 - %3760 = torch.prim.ListConstruct %389, %int2097152_4441 : (!torch.int, !torch.int) -> !torch.list - %3761 = torch.aten.view %3759, %3760 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3761, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_4442 = torch.constant.int 32 - %int2_4443 = torch.constant.int 2 - %int32_4444 = torch.constant.int 32 - %int8_4445 = torch.constant.int 8 - %int128_4446 = torch.constant.int 128 - %3762 = torch.prim.ListConstruct %389, %int32_4442, %int2_4443, %int32_4444, %int8_4445, %int128_4446 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3763 = torch.aten.view %3761, %3762 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3763, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4447 = torch.constant.int 32 - %int8_4448 = torch.constant.int 8 - %int128_4449 = torch.constant.int 128 - %3764 = torch.prim.ListConstruct %3753, %int32_4447, %int8_4448, %int128_4449 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3765 = torch.aten.view %3763, %3764 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3765, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_4450 = torch.constant.int 4 - %int32_4451 = torch.constant.int 32 - %int8_4452 = torch.constant.int 8 - %int128_4453 = torch.constant.int 128 - %3766 = torch.prim.ListConstruct %int4_4450, %398, %int32_4451, %int8_4452, %int128_4453 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3767 = torch.aten.view %3683, %3766 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3767, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_4454 = torch.constant.int 4 - %3768 = torch.aten.mul.int %int4_4454, %398 : !torch.int, !torch.int -> !torch.int - 
%int32_4455 = torch.constant.int 32 - %int8_4456 = torch.constant.int 8 + %int2_4433 = torch.constant.int 2 + %int8_4434 = torch.constant.int 8 + %int32_4435 = torch.constant.int 32 + %int128_4436 = torch.constant.int 128 + %3753 = torch.prim.ListConstruct %297, %int32_4432, %int2_4433, %int8_4434, %int32_4435, %int128_4436 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3754 = torch.aten.view %3752, %3753 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3754, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_4437 = torch.constant.int 2097152 + %3755 = torch.prim.ListConstruct %297, %int2097152_4437 : (!torch.int, !torch.int) -> !torch.list + %3756 = torch.aten.view %3754, %3755 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %3756, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_4438 = torch.constant.int -2 + %3757 = torch.aten.unsqueeze %3712, %int-2_4438 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %3757, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_4439 = torch.constant.int 4 + %int8_4440 = torch.constant.int 8 + %int4_4441 = torch.constant.int 4 + %int128_4442 = torch.constant.int 128 + %3758 = torch.prim.ListConstruct %int4_4439, %298, %int8_4440, %int4_4441, %int128_4442 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_4443 = torch.constant.bool false + %3759 = torch.aten.expand %3757, %3758, %false_4443 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3759, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_4444 = torch.constant.int 0 + %3760 = torch.aten.clone %3759, %int0_4444 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3760, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_4445 = torch.constant.int 4 + %int32_4446 = torch.constant.int 32 + %int128_4447 = torch.constant.int 128 + %3761 = torch.prim.ListConstruct %int4_4445, %298, %int32_4446, %int128_4447 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3762 = torch.aten._unsafe_view %3760, %3761 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3762, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_4448 = torch.constant.int -2 + %3763 = torch.aten.unsqueeze %3586, %int-2_4448 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %3763, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_4449 = torch.constant.int 4 + %int8_4450 = torch.constant.int 8 + %int4_4451 = torch.constant.int 4 + %int128_4452 = torch.constant.int 128 + %3764 = torch.prim.ListConstruct %int4_4449, %298, %int8_4450, %int4_4451, %int128_4452 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_4453 = torch.constant.bool false + %3765 = torch.aten.expand %3763, %3764, %false_4453 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, 
!torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3765, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_4454 = torch.constant.int 0 + %3766 = torch.aten.clone %3765, %int0_4454 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3766, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_4455 = torch.constant.int 4 + %int32_4456 = torch.constant.int 32 %int128_4457 = torch.constant.int 128 - %3769 = torch.prim.ListConstruct %3768, %int32_4455, %int8_4456, %int128_4457 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3770 = torch.aten.view %3767, %3769 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3770, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %3767 = torch.prim.ListConstruct %int4_4455, %298, %int32_4456, %int128_4457 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3768 = torch.aten._unsafe_view %3766, %3767 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3768, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int1_4458 = torch.constant.int 1 - %int1_4459 = torch.constant.int 1 - %3771 = torch.aten.add.Scalar %3741, %int1_4458, %int1_4459 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3771, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_4460 = torch.constant.int 4 - %3772 = torch.aten.mul.int %int4_4460, %398 : !torch.int, !torch.int -> !torch.int - %3773 = torch.prim.ListConstruct %3772 : (!torch.int) -> !torch.list - %3774 = torch.aten.view %3771, %3773 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3774, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %3775 = torch.prim.ListConstruct %3774 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_4461 = torch.constant.bool false - %3776 = torch.aten.index_put %3765, %3775, %3770, %false_4461 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3776, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_4462 = torch.constant.int 32 + %int2_4459 = torch.constant.int 2 + %3769 = torch.aten.transpose.int %3649, %int1_4458, %int2_4459 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3769, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_4460 = torch.constant.int 1 + %int2_4461 = torch.constant.int 2 + %3770 = torch.aten.transpose.int %3762, %int1_4460, %int2_4461 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3770, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_4462 = torch.constant.int 1 %int2_4463 = torch.constant.int 2 - %int32_4464 = torch.constant.int 32 - %int8_4465 = torch.constant.int 8 - %int128_4466 = torch.constant.int 128 - %3777 = torch.prim.ListConstruct %389, %int32_4462, %int2_4463, %int32_4464, %int8_4465, %int128_4466 : (!torch.int, 
!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3778 = torch.aten.view %3776, %3777 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3778, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_4467 = torch.constant.int 2097152 - %3779 = torch.prim.ListConstruct %389, %int2097152_4467 : (!torch.int, !torch.int) -> !torch.list - %3780 = torch.aten.view %3778, %3779 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3780, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_4468 = torch.constant.int -2 - %3781 = torch.aten.unsqueeze %3739, %int-2_4468 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3781, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %3771 = torch.aten.transpose.int %3768, %int1_4462, %int2_4463 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3771, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_4464 = torch.constant.float 0.000000e+00 + %false_4465 = torch.constant.bool false + %none_4466 = torch.constant.none + %3772:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3769, %3770, %3771, %float0.000000e00_4464, %false_4465, %327, %none_4466) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %3772#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_4467 = torch.constant.int 1 + %int2_4468 = torch.constant.int 2 + %3773 = torch.aten.transpose.int %3772#0, %int1_4467, %int2_4468 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3773, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int4_4469 = torch.constant.int 4 - %int8_4470 = torch.constant.int 8 - %int4_4471 = torch.constant.int 4 - %int128_4472 = torch.constant.int 128 - %3782 = torch.prim.ListConstruct %int4_4469, %3724, %int8_4470, %int4_4471, %int128_4472 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_4473 = torch.constant.bool false - %3783 = torch.aten.expand %3781, %3782, %false_4473 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3783, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_4474 = torch.constant.int 0 - %3784 = torch.aten.clone %3783, %int0_4474 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3784, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4096_4470 = torch.constant.int 4096 + %3774 = torch.prim.ListConstruct %int4_4469, %298, %int4096_4470 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3775 = torch.aten.view %3773, %3774 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + 
torch.bind_symbolic_shape %3775, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_4471 = torch.constant.int -2 + %int-1_4472 = torch.constant.int -1 + %3776 = torch.aten.transpose.int %114, %int-2_4471, %int-1_4472 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_4473 = torch.constant.int 5 + %3777 = torch.prims.convert_element_type %3776, %int5_4473 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_4474 = torch.constant.int 4096 + %3778 = torch.prim.ListConstruct %342, %int4096_4474 : (!torch.int, !torch.int) -> !torch.list + %3779 = torch.aten.view %3775, %3778 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3779, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3780 = torch.aten.mm %3779, %3777 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3780, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> %int4_4475 = torch.constant.int 4 - %int32_4476 = torch.constant.int 32 - %int128_4477 = torch.constant.int 128 - %3785 = torch.prim.ListConstruct %int4_4475, %3724, %int32_4476, %int128_4477 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3786 = torch.aten._unsafe_view %3784, %3785 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3786, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_4478 = torch.constant.int -2 - %3787 = torch.aten.unsqueeze %3683, %int-2_4478 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3787, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_4479 = torch.constant.int 1 - %3788 = torch.aten.size.int %3677, %int1_4479 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_4480 = torch.constant.int 4 - %int8_4481 = torch.constant.int 8 - %int4_4482 = torch.constant.int 4 - %int128_4483 = torch.constant.int 128 - %3789 = torch.prim.ListConstruct %int4_4480, %3788, %int8_4481, %int4_4482, %int128_4483 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_4484 = torch.constant.bool false - %3790 = torch.aten.expand %3787, %3789, %false_4484 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3790, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_4485 = torch.constant.int 0 - %3791 = torch.aten.clone %3790, %int0_4485 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3791, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_4486 = torch.constant.int 4 - %int32_4487 = torch.constant.int 32 - %int128_4488 = torch.constant.int 128 - %3792 = torch.prim.ListConstruct %int4_4486, %3788, %int32_4487, %int128_4488 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3793 = torch.aten._unsafe_view %3791, %3792 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3793, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_4489 = 
torch.constant.int 1 - %int2_4490 = torch.constant.int 2 - %3794 = torch.aten.transpose.int %3711, %int1_4489, %int2_4490 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3794, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_4491 = torch.constant.int 1 - %int2_4492 = torch.constant.int 2 - %3795 = torch.aten.transpose.int %3786, %int1_4491, %int2_4492 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3795, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_4493 = torch.constant.int 1 - %int2_4494 = torch.constant.int 2 - %3796 = torch.aten.transpose.int %3793, %int1_4493, %int2_4494 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3796, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_4495 = torch.constant.float 0.000000e+00 - %true_4496 = torch.constant.bool true - %none_4497 = torch.constant.none - %none_4498 = torch.constant.none - %3797:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3794, %3795, %3796, %float0.000000e00_4495, %true_4496, %none_4497, %none_4498) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %3797#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_4499 = torch.constant.int 1 - %int2_4500 = torch.constant.int 2 - %3798 = torch.aten.transpose.int %3797#0, %int1_4499, %int2_4500 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3798, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_4501 = torch.constant.int 4 - %int4096_4502 = torch.constant.int 4096 - %3799 = torch.prim.ListConstruct %int4_4501, %3696, %int4096_4502 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3800 = torch.aten.view %3798, %3799 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3800, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_4503 = torch.constant.int -2 - %int-1_4504 = torch.constant.int -1 - %3801 = torch.aten.transpose.int %158, %int-2_4503, %int-1_4504 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_4505 = torch.constant.int 4 - %3802 = torch.aten.mul.int %int4_4505, %3696 : !torch.int, !torch.int -> !torch.int - %int4096_4506 = torch.constant.int 4096 - %3803 = torch.prim.ListConstruct %3802, %int4096_4506 : (!torch.int, !torch.int) -> !torch.list - %3804 = torch.aten.view %3800, %3803 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3804, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %3805 = torch.aten.mm %3804, %3801 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %3805, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_4507 = torch.constant.int 4 - %int4096_4508 = 
torch.constant.int 4096 - %3806 = torch.prim.ListConstruct %int4_4507, %3696, %int4096_4508 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3807 = torch.aten.view %3805, %3806 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3807, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_4509 = torch.constant.int 1 - %3808 = torch.aten.add.Tensor %3646, %3807, %int1_4509 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3808, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_4510 = torch.constant.int 6 - %3809 = torch.prims.convert_element_type %3808, %int6_4510 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3809, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_4511 = torch.constant.int 2 - %3810 = torch.aten.pow.Tensor_Scalar %3809, %int2_4511 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3810, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_4512 = torch.constant.int -1 - %3811 = torch.prim.ListConstruct %int-1_4512 : (!torch.int) -> !torch.list - %true_4513 = torch.constant.bool true - %none_4514 = torch.constant.none - %3812 = torch.aten.mean.dim %3810, %3811, %true_4513, %none_4514 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3812, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_4515 = torch.constant.float 9.9999997473787516E-6 - %int1_4516 = torch.constant.int 1 - %3813 = torch.aten.add.Scalar %3812, %float9.999990e-06_4515, %int1_4516 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3813, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3814 = torch.aten.rsqrt %3813 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %3814, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %3815 = torch.aten.mul.Tensor %3809, %3814 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %3815, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int4096_4476 = torch.constant.int 4096 + %3781 = torch.prim.ListConstruct %int4_4475, %298, %int4096_4476 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3782 = torch.aten.view %3780, %3781 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3782, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_4477 = torch.constant.int 1 + %3783 = torch.aten.add.Tensor %3549, %3782, %int1_4477 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3783, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_4478 = torch.constant.int 6 + %3784 = torch.prims.convert_element_type %3783, %int6_4478 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3784, [%294], affine_map<()[s0] -> 
(4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_4479 = torch.constant.int 2 + %3785 = torch.aten.pow.Tensor_Scalar %3784, %int2_4479 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3785, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_4480 = torch.constant.int -1 + %3786 = torch.prim.ListConstruct %int-1_4480 : (!torch.int) -> !torch.list + %true_4481 = torch.constant.bool true + %none_4482 = torch.constant.none + %3787 = torch.aten.mean.dim %3785, %3786, %true_4481, %none_4482 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3787, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_4483 = torch.constant.float 9.9999997473787516E-6 + %int1_4484 = torch.constant.int 1 + %3788 = torch.aten.add.Scalar %3787, %float9.999990e-06_4483, %int1_4484 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3788, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %3789 = torch.aten.rsqrt %3788 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %3789, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %3790 = torch.aten.mul.Tensor %3784, %3789 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3790, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_4485 = torch.constant.int 5 + %3791 = torch.prims.convert_element_type %3790, %int5_4485 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3791, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %3792 = torch.aten.mul.Tensor %115, %3791 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %3792, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_4486 = torch.constant.int 5 + %3793 = torch.prims.convert_element_type %3792, %int5_4486 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %3793, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_4487 = torch.constant.int -2 + %int-1_4488 = torch.constant.int -1 + %3794 = torch.aten.transpose.int %116, %int-2_4487, %int-1_4488 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_4489 = torch.constant.int 5 + %3795 = torch.prims.convert_element_type %3794, %int5_4489 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_4490 = torch.constant.int 4096 + %3796 = torch.prim.ListConstruct %342, %int4096_4490 : (!torch.int, !torch.int) -> !torch.list + %3797 = torch.aten.view %3793, %3796 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %3797, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %3798 = torch.aten.mm %3797, %3795 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %3798, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_4491 = 
+ %int-2_4487 = torch.constant.int -2
+ %int-1_4488 = torch.constant.int -1
+ %3794 = torch.aten.transpose.int %116, %int-2_4487, %int-1_4488 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int5_4489 = torch.constant.int 5
+ %3795 = torch.prims.convert_element_type %3794, %int5_4489 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int4096_4490 = torch.constant.int 4096
+ %3796 = torch.prim.ListConstruct %342, %int4096_4490 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3797 = torch.aten.view %3793, %3796 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %3797, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %3798 = torch.aten.mm %3797, %3795 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
+ torch.bind_symbolic_shape %3798, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
+ %int4_4491 = torch.constant.int 4
+ %int14336_4492 = torch.constant.int 14336
+ %3799 = torch.prim.ListConstruct %int4_4491, %298, %int14336_4492 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3800 = torch.aten.view %3798, %3799 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
+ torch.bind_symbolic_shape %3800, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
+ %3801 = torch.aten.silu %3800 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
+ torch.bind_symbolic_shape %3801, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
+ %int-2_4493 = torch.constant.int -2
+ %int-1_4494 = torch.constant.int -1
+ %3802 = torch.aten.transpose.int %117, %int-2_4493, %int-1_4494 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int5_4495 = torch.constant.int 5
+ %3803 = torch.prims.convert_element_type %3802, %int5_4495 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int4096_4496 = torch.constant.int 4096
+ %3804 = torch.prim.ListConstruct %342, %int4096_4496 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3805 = torch.aten.view %3793, %3804 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %3805, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %3806 = torch.aten.mm %3805, %3803 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
+ torch.bind_symbolic_shape %3806, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
+ %int4_4497 = torch.constant.int 4
+ %int14336_4498 = torch.constant.int 14336
+ %3807 = torch.prim.ListConstruct %int4_4497, %298, %int14336_4498 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3808 = torch.aten.view %3806, %3807 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
+ torch.bind_symbolic_shape %3808, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
+ %3809 = torch.aten.mul.Tensor %3801, %3808 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
+ torch.bind_symbolic_shape %3809, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
+ %int-2_4499 = torch.constant.int -2
+ %int-1_4500 = torch.constant.int -1
+ %3810 = torch.aten.transpose.int %118, %int-2_4499, %int-1_4500 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
+ %int5_4501 = torch.constant.int 5
+ %3811 = torch.prims.convert_element_type %3810, %int5_4501 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16>
+ %int14336_4502 = torch.constant.int 14336
+ %3812 = torch.prim.ListConstruct %342, %int14336_4502 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3813 = torch.aten.view %3809, %3812 : !torch.vtensor<[4,?,14336],f16>, !torch.list<int> -> !torch.vtensor<[?,14336],f16>
+ torch.bind_symbolic_shape %3813, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
+ %3814 = torch.aten.mm %3813, %3811 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %3814, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
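// The ffn_down result is reshaped back to [4, seq, 4096] and added to the
// residual stream; a second RMSNorm (same eps) then feeds the attention
// projections below.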
+ %int4_4503 = torch.constant.int 4
+ %int4096_4504 = torch.constant.int 4096
+ %3815 = torch.prim.ListConstruct %int4_4503, %298, %int4096_4504 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3816 = torch.aten.view %3814, %3815 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %3816, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int1_4505 = torch.constant.int 1
+ %3817 = torch.aten.add.Tensor %3783, %3816, %int1_4505 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %3817, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int6_4506 = torch.constant.int 6
+ %3818 = torch.prims.convert_element_type %3817, %int6_4506 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %3818, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int2_4507 = torch.constant.int 2
+ %3819 = torch.aten.pow.Tensor_Scalar %3818, %int2_4507 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %3819, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int-1_4508 = torch.constant.int -1
+ %3820 = torch.prim.ListConstruct %int-1_4508 : (!torch.int) -> !torch.list<int>
+ %true_4509 = torch.constant.bool true
+ %none_4510 = torch.constant.none
+ %3821 = torch.aten.mean.dim %3819, %3820, %true_4509, %none_4510 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %3821, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %float9.999990e-06_4511 = torch.constant.float 9.9999997473787516E-6
+ %int1_4512 = torch.constant.int 1
+ %3822 = torch.aten.add.Scalar %3821, %float9.999990e-06_4511, %int1_4512 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %3822, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %3823 = torch.aten.rsqrt %3822 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %3823, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %3824 = torch.aten.mul.Tensor %3818, %3823 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %3824, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_4513 = torch.constant.int 5
+ %3825 = torch.prims.convert_element_type %3824, %int5_4513 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %3825, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %3826 = torch.aten.mul.Tensor %119, %3825 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %3826, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_4514 = torch.constant.int 5
+ %3827 = torch.prims.convert_element_type %3826, %int5_4514 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %3827, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
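// Attention projections: Q uses attn_q (4096 -> 4096, 32 heads x 128), while
// K and V use attn_k / attn_v (4096 -> 1024, 8 KV heads x 128), i.e.
// grouped-query attention with a 4:1 query-to-KV head ratio.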
+ %int-2_4515 = torch.constant.int -2
+ %int-1_4516 = torch.constant.int -1
+ %3828 = torch.aten.transpose.int %120, %int-2_4515, %int-1_4516 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
  %int5_4517 = torch.constant.int 5
- %3816 = torch.prims.convert_element_type %3815, %int5_4517 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %3816, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %3817 = torch.aten.mul.Tensor %159, %3816 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %3817, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_4518 = torch.constant.int 5
- %3818 = torch.prims.convert_element_type %3817, %int5_4518 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %3818, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_4519 = torch.constant.int -2
- %int-1_4520 = torch.constant.int -1
- %3819 = torch.aten.transpose.int %160, %int-2_4519, %int-1_4520 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_4521 = torch.constant.int 4
- %3820 = torch.aten.mul.int %int4_4521, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_4522 = torch.constant.int 4096
- %3821 = torch.prim.ListConstruct %3820, %int4096_4522 : (!torch.int, !torch.int) -> !torch.list<int>
- %3822 = torch.aten.view %3818, %3821 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %3822, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %3823 = torch.aten.mm %3822, %3819 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %3823, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %int4_4523 = torch.constant.int 4
- %int14336_4524 = torch.constant.int 14336
- %3824 = torch.prim.ListConstruct %int4_4523, %306, %int14336_4524 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %3825 = torch.aten.view %3823, %3824 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %3825, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %3826 = torch.aten.silu %3825 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %3826, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %int-2_4525 = torch.constant.int -2
- %int-1_4526 = torch.constant.int -1
- %3827 = torch.aten.transpose.int %161, %int-2_4525, %int-1_4526 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_4527 = torch.constant.int 4
- %3828 = torch.aten.mul.int %int4_4527, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_4528 = torch.constant.int 4096
- %3829 = torch.prim.ListConstruct %3828, %int4096_4528 : (!torch.int, !torch.int) -> !torch.list<int>
- %3830 = torch.aten.view %3818, %3829 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %3830, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %3831 = torch.aten.mm %3830, %3827 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %3831, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %int4_4529 = torch.constant.int 4
- %int14336_4530 = torch.constant.int 14336
- %3832 = torch.prim.ListConstruct %int4_4529, %306, %int14336_4530 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %3833 = torch.aten.view %3831, %3832 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %3833, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %3834 = torch.aten.mul.Tensor %3826, %3833 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %3834, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %int-2_4531 = torch.constant.int -2
- %int-1_4532 = torch.constant.int -1
- %3835 = torch.aten.transpose.int %162, %int-2_4531, %int-1_4532 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
- %int1_4533 = torch.constant.int 1
- %3836 = torch.aten.size.int %3825, %int1_4533 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int
- %int4_4534 = torch.constant.int 4
- %3837 = torch.aten.mul.int %int4_4534, %3836 : !torch.int, !torch.int -> !torch.int
- %int14336_4535 = torch.constant.int 14336
- %3838 = torch.prim.ListConstruct %3837, %int14336_4535 : (!torch.int, !torch.int) -> !torch.list<int>
- %3839 = torch.aten.view %3834, %3838 : !torch.vtensor<[4,?,14336],f16>, !torch.list<int> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %3839, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %3840 = torch.aten.mm %3839, %3835 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %3840, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %3829 = torch.prims.convert_element_type %3828, %int5_4517 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int4096_4518 = torch.constant.int 4096
+ %3830 = torch.prim.ListConstruct %342, %int4096_4518 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3831 = torch.aten.view %3827, %3830 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %3831, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %3832 = torch.aten.mm %3831, %3829 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %3832, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %int4_4519 = torch.constant.int 4
+ %int4096_4520 = torch.constant.int 4096
+ %3833 = torch.prim.ListConstruct %int4_4519, %298, %int4096_4520 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3834 = torch.aten.view %3832, %3833 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %3834, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_4521 = torch.constant.int -2
+ %int-1_4522 = torch.constant.int -1
+ %3835 = torch.aten.transpose.int %121, %int-2_4521, %int-1_4522 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int5_4523 = torch.constant.int 5
+ %3836 = torch.prims.convert_element_type %3835, %int5_4523 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int4096_4524 = torch.constant.int 4096
+ %3837 = torch.prim.ListConstruct %342, %int4096_4524 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3838 = torch.aten.view %3827, %3837 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %3838, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %3839 = torch.aten.mm %3838, %3836 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
+ torch.bind_symbolic_shape %3839, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
+ %int4_4525 = torch.constant.int 4
+ %int1024_4526 = torch.constant.int 1024
+ %3840 = torch.prim.ListConstruct %int4_4525, %298, %int1024_4526 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3841 = torch.aten.view %3839, %3840 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
+ torch.bind_symbolic_shape %3841, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
+ %int-2_4527 = torch.constant.int -2
+ %int-1_4528 = torch.constant.int -1
+ %3842 = torch.aten.transpose.int %122, %int-2_4527, %int-1_4528 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int5_4529 = torch.constant.int 5
+ %3843 = torch.prims.convert_element_type %3842, %int5_4529 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int4096_4530 = torch.constant.int 4096
+ %3844 = torch.prim.ListConstruct %342, %int4096_4530 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3845 = torch.aten.view %3827, %3844 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %3845, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %3846 = torch.aten.mm %3845, %3843 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
+ torch.bind_symbolic_shape %3846, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
+ %int4_4531 = torch.constant.int 4
+ %int1024_4532 = torch.constant.int 1024
+ %3847 = torch.prim.ListConstruct %int4_4531, %298, %int1024_4532 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3848 = torch.aten.view %3846, %3847 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
+ torch.bind_symbolic_shape %3848, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
+ %int4_4533 = torch.constant.int 4
+ %int32_4534 = torch.constant.int 32
+ %int128_4535 = torch.constant.int 128
+ %3849 = torch.prim.ListConstruct %int4_4533, %298, %int32_4534, %int128_4535 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3850 = torch.aten.view %3834, %3849 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %3850, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
  %int4_4536 = torch.constant.int 4
- %int4096_4537 = torch.constant.int 4096
- %3841 = torch.prim.ListConstruct %int4_4536, %3836, %int4096_4537 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %3842 = torch.aten.view %3840, %3841 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %3842, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int1_4538 = torch.constant.int 1
- %3843 = torch.aten.add.Tensor %3808, %3842, %int1_4538 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %3843, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
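// Above, Q is reshaped to [4, seq, 32, 128]; the removed lines that follow are
// the old module's pre-attention RMSNorm under its previous SSA numbering,
// interleaved with the new path, which continues by reshaping K and V to
// [4, seq, 8, 128].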
- %int6_4539 = torch.constant.int 6
- %3844 = torch.prims.convert_element_type %3843, %int6_4539 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %3844, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int2_4540 = torch.constant.int 2
- %3845 = torch.aten.pow.Tensor_Scalar %3844, %int2_4540 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %3845, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int-1_4541 = torch.constant.int -1
- %3846 = torch.prim.ListConstruct %int-1_4541 : (!torch.int) -> !torch.list<int>
- %true_4542 = torch.constant.bool true
+ %int8_4537 = torch.constant.int 8
+ %int128_4538 = torch.constant.int 128
+ %3851 = torch.prim.ListConstruct %int4_4536, %298, %int8_4537, %int128_4538 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3852 = torch.aten.view %3841, %3851 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %3852, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int4_4539 = torch.constant.int 4
+ %int8_4540 = torch.constant.int 8
+ %int128_4541 = torch.constant.int 128
+ %3853 = torch.prim.ListConstruct %int4_4539, %298, %int8_4540, %int128_4541 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3854 = torch.aten.view %3848, %3853 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %3854, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int131072_4542 = torch.constant.int 131072
  %none_4543 = torch.constant.none
- %3847 = torch.aten.mean.dim %3845, %3846, %true_4542, %none_4543 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %3847, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %float9.999990e-06_4544 = torch.constant.float 9.9999997473787516E-6
- %int1_4545 = torch.constant.int 1
- %3848 = torch.aten.add.Scalar %3847, %float9.999990e-06_4544, %int1_4545 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %3848, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %3849 = torch.aten.rsqrt %3848 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %3849, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %3850 = torch.aten.mul.Tensor %3844, %3849 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %3850, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_4546 = torch.constant.int 5
- %3851 = torch.prims.convert_element_type %3850, %int5_4546 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %3851, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %3852 = torch.aten.mul.Tensor %163, %3851 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %3852, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_4547 = torch.constant.int 5
- %3853 = torch.prims.convert_element_type %3852, %int5_4547 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %3853, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_4548 = torch.constant.int -2
- %int-1_4549 = torch.constant.int -1
- %3854 = torch.aten.transpose.int %164, %int-2_4548, %int-1_4549 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %none_4544 = torch.constant.none
+ %cpu_4545 = torch.constant.device "cpu"
+ %false_4546 = torch.constant.bool false
+ %3855 = torch.aten.arange %int131072_4542, %none_4543, %none_4544, %cpu_4545, %false_4546 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
+ %int0_4547 = torch.constant.int 0
+ %int128_4548 = torch.constant.int 128
+ %int2_4549 = torch.constant.int 2
  %int4_4550 = torch.constant.int 4
- %3855 = torch.aten.mul.int %int4_4550, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_4551 = torch.constant.int 4096
- %3856 = torch.prim.ListConstruct %3855, %int4096_4551 : (!torch.int, !torch.int) -> !torch.list<int>
- %3857 = torch.aten.view %3853, %3856 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %3857, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %3858 = torch.aten.mm %3857, %3854 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %3858, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_4552 = torch.constant.int 4
- %int4096_4553 = torch.constant.int 4096
- %3859 = torch.prim.ListConstruct %int4_4552, %306, %int4096_4553 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %3860 = torch.aten.view %3858, %3859 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %3860, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_4554 = torch.constant.int -2
- %int-1_4555 = torch.constant.int -1
- %3861 = torch.aten.transpose.int %165, %int-2_4554, %int-1_4555 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_4556 = torch.constant.int 4
- %3862 = torch.aten.mul.int %int4_4556, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_4557 = torch.constant.int 4096
- %3863 = torch.prim.ListConstruct %3862, %int4096_4557 : (!torch.int, !torch.int) -> !torch.list<int>
- %3864 = torch.aten.view %3853, %3863 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %3864, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %3865 = torch.aten.mm %3864, %3861 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
- torch.bind_symbolic_shape %3865, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
- %int4_4558 = torch.constant.int 4
- %int1024_4559 = torch.constant.int 1024
- %3866 = torch.prim.ListConstruct %int4_4558, %306, %int1024_4559 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %3867 = torch.aten.view %3865, %3866 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
- torch.bind_symbolic_shape %3867, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
- %int-2_4560 = torch.constant.int -2
- %int-1_4561 = torch.constant.int -1
- %3868 = torch.aten.transpose.int %166, %int-2_4560, %int-1_4561 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_4562 = torch.constant.int 4
- %3869 = torch.aten.mul.int %int4_4562, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_4563 = torch.constant.int 4096
- %3870 = torch.prim.ListConstruct %3869, %int4096_4563 : (!torch.int, !torch.int) -> !torch.list<int>
- %3871 = torch.aten.view %3853, %3870 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %3871, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %3872 = torch.aten.mm %3871, %3868 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
- torch.bind_symbolic_shape %3872, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
- %int4_4564 = torch.constant.int 4
- %int1024_4565 = torch.constant.int 1024
- %3873 = torch.prim.ListConstruct %int4_4564, %306, %int1024_4565 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %3874 = torch.aten.view %3872, %3873 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
- torch.bind_symbolic_shape %3874, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
- %int4_4566 = torch.constant.int 4
- %int32_4567 = torch.constant.int 32
- %int128_4568 = torch.constant.int 128
- %3875 = torch.prim.ListConstruct %int4_4566, %306, %int32_4567, %int128_4568 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %3876 = torch.aten.view %3860, %3875 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %3876, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int4_4569 = torch.constant.int 4
- %int8_4570 = torch.constant.int 8
- %int128_4571 = torch.constant.int 128
- %3877 = torch.prim.ListConstruct %int4_4569, %306, %int8_4570, %int128_4571 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %3878 = torch.aten.view %3867, %3877 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %3878, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int4_4572 = torch.constant.int 4
- %int8_4573 = torch.constant.int 8
- %int128_4574 = torch.constant.int 128
- %3879 = torch.prim.ListConstruct %int4_4572, %306, %int8_4573, %int128_4574 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %3880 = torch.aten.view %3874, %3879 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %3880, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int131072_4575 = torch.constant.int 131072
- %none_4576 = torch.constant.none
- %none_4577 = torch.constant.none
- %cpu_4578 = torch.constant.device "cpu"
- %false_4579 = torch.constant.bool false
- %3881 = torch.aten.arange %int131072_4575, %none_4576, %none_4577, %cpu_4578, %false_4579 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
+ %none_4551 = torch.constant.none
+ %cpu_4552 = torch.constant.device "cpu"
+ %false_4553 = torch.constant.bool false
+ %3856 = torch.aten.arange.start_step %int0_4547, %int128_4548, %int2_4549, %int4_4550, %none_4551, %cpu_4552, %false_4553 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64>
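// New RoPE path: instead of 128 interleaved frequencies via floor_divide,
// arange(0, 128, 2) yields the 64 base frequencies directly. The block below
// applies llama3-style frequency scaling: inv_freq = 500000^(-2i/128) is
// divided by 8 for wavelengths above 8192 (the original context length),
// kept as-is below 2048, and smoothly interpolated in between.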
+ %int6_4554 = torch.constant.int 6
+ %3857 = torch.prims.convert_element_type %3856, %int6_4554 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32>
+ %int128_4555 = torch.constant.int 128
+ %3858 = torch.aten.div.Scalar %3857, %int128_4555 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float5.000000e05_4556 = torch.constant.float 5.000000e+05
+ %3859 = torch.aten.pow.Scalar %float5.000000e05_4556, %3858 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %3860 = torch.aten.reciprocal %3859 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float1.000000e00_4557 = torch.constant.float 1.000000e+00
+ %3861 = torch.aten.mul.Scalar %3860, %float1.000000e00_4557 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %3862 = torch.aten.reciprocal %3861 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float6.283190e00_4558 = torch.constant.float 6.2831853071795862
+ %3863 = torch.aten.mul.Scalar %3862, %float6.283190e00_4558 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %float8.192000e03_4559 = torch.constant.float 8.192000e+03
+ %3864 = torch.aten.gt.Scalar %3863, %float8.192000e03_4559 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %int8_4560 = torch.constant.int 8
+ %3865 = torch.aten.div.Scalar %3861, %int8_4560 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %3866 = torch.aten.where.self %3864, %3865, %3861 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %3867 = torch.aten.reciprocal %3863 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8192_4561 = torch.constant.int 8192
+ %3868 = torch.aten.mul.Scalar %3867, %int8192_4561 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %int1_4562 = torch.constant.int 1
+ %int1_4563 = torch.constant.int 1
+ %3869 = torch.aten.sub.Scalar %3868, %int1_4562, %int1_4563 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %int3_4564 = torch.constant.int 3
+ %3870 = torch.aten.div.Scalar %3869, %int3_4564 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %int1_4565 = torch.constant.int 1
+ %int1_4566 = torch.constant.int 1
+ %3871 = torch.aten.rsub.Scalar %3870, %int1_4565, %int1_4566 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %3872 = torch.aten.mul.Tensor %3871, %3866 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8_4567 = torch.constant.int 8
+ %3873 = torch.aten.div.Scalar %3872, %int8_4567 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %3874 = torch.aten.mul.Tensor %3870, %3866 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int1_4568 = torch.constant.int 1
+ %3875 = torch.aten.add.Tensor %3873, %3874, %int1_4568 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float2.048000e03_4569 = torch.constant.float 2.048000e+03
+ %3876 = torch.aten.lt.Scalar %3863, %float2.048000e03_4569 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %3877 = torch.aten.bitwise_not %3876 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %float8.192000e03_4570 = torch.constant.float 8.192000e+03
+ %3878 = torch.aten.gt.Scalar %3863, %float8.192000e03_4570 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %3879 = torch.aten.bitwise_not %3878 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %3880 = torch.aten.mul.Tensor %3877, %3879 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %3881 = torch.aten.where.self %3880, %3875, %3866 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %3882 = torch.prim.ListConstruct %3881, %3881 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list<vtensor>
+ %int-1_4571 = torch.constant.int -1
+ %3883 = torch.aten.cat %3882, %int-1_4571 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[128],f32>
+ %int6_4572 = torch.constant.int 6
+ %3884 = torch.prims.convert_element_type %3883, %int6_4572 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
+ %int1_4573 = torch.constant.int 1
+ %3885 = torch.aten.unsqueeze %3855, %int1_4573 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
+ %int6_4574 = torch.constant.int 6
+ %3886 = torch.prims.convert_element_type %3885, %int6_4574 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32>
+ %int0_4575 = torch.constant.int 0
+ %3887 = torch.aten.unsqueeze %3884, %int0_4575 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+ %int6_4576 = torch.constant.int 6
+ %3888 = torch.prims.convert_element_type %3887, %int6_4576 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+ %3889 = torch.aten.mul.Tensor %3886, %3888 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %3890 = torch.aten.cos %3889 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %int5_4577 = torch.constant.int 5
+ %3891 = torch.prims.convert_element_type %3890, %int5_4577 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
+ %3892 = torch.aten.sin %3889 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %int5_4578 = torch.constant.int 5
+ %3893 = torch.prims.convert_element_type %3892, %int5_4578 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
+ %int0_4579 = torch.constant.int 0
  %int0_4580 = torch.constant.int 0
- %int128_4581 = torch.constant.int 128
- %none_4582 = torch.constant.none
- %none_4583 = torch.constant.none
- %cpu_4584 = torch.constant.device "cpu"
- %false_4585 = torch.constant.bool false
- %3882 = torch.aten.arange.start %int0_4580, %int128_4581, %none_4582, %none_4583, %cpu_4584, %false_4585 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64>
- %int2_4586 = torch.constant.int 2
- %3883 = torch.aten.floor_divide.Scalar %3882, %int2_4586 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64>
- %int6_4587 = torch.constant.int 6
- %3884 = torch.prims.convert_element_type %3883, %int6_4587 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32>
- %int128_4588 = torch.constant.int 128
- %3885 = torch.aten.div.Scalar %3884, %int128_4588 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
- %float2.000000e00_4589 = torch.constant.float 2.000000e+00
- %3886 = torch.aten.mul.Scalar %3885, %float2.000000e00_4589 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
- %float5.000000e05_4590 = torch.constant.float 5.000000e+05
- %3887 = torch.aten.pow.Scalar %float5.000000e05_4590, %3886 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
- %3888 = torch.aten.reciprocal %3887 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
- %float1.000000e00_4591 = torch.constant.float 1.000000e+00
- %3889 = torch.aten.mul.Scalar %3888, %float1.000000e00_4591 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
+ %int1_4581 = torch.constant.int 1
+ %3894 = torch.aten.slice.Tensor %3891, %int0_4579, %int0_4580, %298, %int1_4581 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %3894, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int1_4582 = torch.constant.int 1
+ %int0_4583 = torch.constant.int 0
+ %int9223372036854775807_4584 = torch.constant.int 9223372036854775807
+ %int1_4585 = torch.constant.int 1
+ %3895 = torch.aten.slice.Tensor %3894, %int1_4582, %int0_4583, %int9223372036854775807_4584, %int1_4585 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %3895, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int0_4586 = torch.constant.int 0
+ %int0_4587 = torch.constant.int 0
+ %int1_4588 = torch.constant.int 1
+ %3896 = torch.aten.slice.Tensor %3893, %int0_4586, %int0_4587, %298, %int1_4588 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %3896, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int1_4589 = torch.constant.int 1
+ %int0_4590 = torch.constant.int 0
+ %int9223372036854775807_4591 = torch.constant.int 9223372036854775807
  %int1_4592 = torch.constant.int 1
- %3890 = torch.aten.unsqueeze %3881, %int1_4592 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
+ %3897 = torch.aten.slice.Tensor %3896, %int1_4589, %int0_4590, %int9223372036854775807_4591, %int1_4592 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %3897, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
  %int0_4593 = torch.constant.int 0
- %3891 = torch.aten.unsqueeze %3889, %int0_4593 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
- %3892 = torch.aten.mul.Tensor %3890, %3891 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %3898 = torch.aten.unsqueeze %3895, %int0_4593 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %3898, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
  %int1_4594 = torch.constant.int 1
- %3893 = torch.aten.size.int %3860, %int1_4594 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int
  %int0_4595 = torch.constant.int 0
- %3894 = torch.aten.add.int %int0_4595, %3893 : !torch.int, !torch.int -> !torch.int
- %int0_4596 = torch.constant.int 0
- %int0_4597 = torch.constant.int 0
- %int1_4598 = torch.constant.int 1
- %3895 = torch.aten.slice.Tensor %3892, %int0_4596, %int0_4597, %3894, %int1_4598 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %3895, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
- %int1_4599 = torch.constant.int 1
+ %int9223372036854775807_4596 = torch.constant.int 9223372036854775807
+ %int1_4597 = torch.constant.int 1
+ %3899 = torch.aten.slice.Tensor %3898, %int1_4594, %int0_4595, %int9223372036854775807_4596, %int1_4597 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %3899, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
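// The cos/sin tables are sliced to the active sequence length (s0 * 32
// positions) above and then unsqueezed/repeated below to broadcast over the
// batch of 4 and over the head dimension.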
+ %int2_4598 = torch.constant.int 2
+ %3900 = torch.aten.unsqueeze %3899, %int2_4598 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %3900, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_4599 = torch.constant.int 3
  %int0_4600 = torch.constant.int 0
  %int9223372036854775807_4601 = torch.constant.int 9223372036854775807
  %int1_4602 = torch.constant.int 1
- %3896 = torch.aten.slice.Tensor %3895, %int1_4599, %int0_4600, %int9223372036854775807_4601, %int1_4602 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %3896, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
- %int1_4603 = torch.constant.int 1
- %int0_4604 = torch.constant.int 0
- %int9223372036854775807_4605 = torch.constant.int 9223372036854775807
+ %3901 = torch.aten.slice.Tensor %3900, %int3_4599, %int0_4600, %int9223372036854775807_4601, %int1_4602 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %3901, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int4_4603 = torch.constant.int 4
+ %int1_4604 = torch.constant.int 1
+ %int1_4605 = torch.constant.int 1
  %int1_4606 = torch.constant.int 1
- %3897 = torch.aten.slice.Tensor %3896, %int1_4603, %int0_4604, %int9223372036854775807_4605, %int1_4606 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %3897, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
+ %3902 = torch.prim.ListConstruct %int4_4603, %int1_4604, %int1_4605, %int1_4606 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3903 = torch.aten.repeat %3901, %3902 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %3903, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
  %int0_4607 = torch.constant.int 0
- %3898 = torch.aten.unsqueeze %3897, %int0_4607 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %3898, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
+ %3904 = torch.aten.unsqueeze %3897, %int0_4607 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %3904, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
  %int1_4608 = torch.constant.int 1
  %int0_4609 = torch.constant.int 0
  %int9223372036854775807_4610 = torch.constant.int 9223372036854775807
  %int1_4611 = torch.constant.int 1
- %3899 = torch.aten.slice.Tensor %3898, %int1_4608, %int0_4609, %int9223372036854775807_4610, %int1_4611 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %3899, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
+ %3905 = torch.aten.slice.Tensor %3904, %int1_4608, %int0_4609, %int9223372036854775807_4610, %int1_4611 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %3905, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
  %int2_4612 = torch.constant.int 2
- %int0_4613 = torch.constant.int 0
- %int9223372036854775807_4614 = torch.constant.int 9223372036854775807
- %int1_4615 = torch.constant.int 1
- %3900 = torch.aten.slice.Tensor %3899, %int2_4612, %int0_4613, %int9223372036854775807_4614, %int1_4615 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %3900, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
- %int4_4616 = torch.constant.int 4
- %int1_4617 = torch.constant.int 1
+ %3906 = torch.aten.unsqueeze %3905, %int2_4612 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %3906, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_4613 = torch.constant.int 3
+ %int0_4614 = torch.constant.int 0
+ %int9223372036854775807_4615 = torch.constant.int 9223372036854775807
+ %int1_4616 = torch.constant.int 1
+ %3907 = torch.aten.slice.Tensor %3906, %int3_4613, %int0_4614, %int9223372036854775807_4615, %int1_4616 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %3907, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int4_4617 = torch.constant.int 4
  %int1_4618 = torch.constant.int 1
- %3901 = torch.prim.ListConstruct %int4_4616, %int1_4617, %int1_4618 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %3902 = torch.aten.repeat %3900, %3901 : !torch.vtensor<[1,?,128],f32>, !torch.list<int> -> !torch.vtensor<[4,?,128],f32>
- torch.bind_symbolic_shape %3902, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32>
- %int6_4619 = torch.constant.int 6
- %3903 = torch.prims.convert_element_type %3876, %int6_4619 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32>
- torch.bind_symbolic_shape %3903, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32>
- %3904 = torch_c.to_builtin_tensor %3903 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32>
- %3905 = torch_c.to_builtin_tensor %3902 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32>
- %3906 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%3904, %3905) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32>
- %3907 = torch_c.from_builtin_tensor %3906 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32>
- torch.bind_symbolic_shape %3907, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32>
- %int5_4620 = torch.constant.int 5
- %3908 = torch.prims.convert_element_type %3907, %int5_4620 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %3908, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int131072_4621 = torch.constant.int 131072
- %none_4622 = torch.constant.none
- %none_4623 = torch.constant.none
- %cpu_4624 = torch.constant.device "cpu"
- %false_4625 = torch.constant.bool false
- %3909 = torch.aten.arange %int131072_4621, %none_4622, %none_4623, %cpu_4624, %false_4625 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
- %int0_4626 = torch.constant.int 0
- %int128_4627 = torch.constant.int 128
- %none_4628 = torch.constant.none
- %none_4629 = torch.constant.none
- %cpu_4630 = torch.constant.device "cpu"
- %false_4631 = torch.constant.bool false
- %3910 = torch.aten.arange.start %int0_4626, %int128_4627, %none_4628, %none_4629, %cpu_4630, %false_4631 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64>
- %int2_4632 = torch.constant.int 2
- %3911 = torch.aten.floor_divide.Scalar %3910, %int2_4632 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64>
- %int6_4633 = torch.constant.int 6
- %3912 = torch.prims.convert_element_type %3911, %int6_4633 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32>
- %int128_4634 = torch.constant.int 128
- %3913 = torch.aten.div.Scalar %3912, %int128_4634 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
- %float2.000000e00_4635 = torch.constant.float 2.000000e+00
- %3914 = torch.aten.mul.Scalar %3913, %float2.000000e00_4635 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
- %float5.000000e05_4636 = torch.constant.float 5.000000e+05
- %3915 = torch.aten.pow.Scalar %float5.000000e05_4636, %3914 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
- %3916 = torch.aten.reciprocal %3915 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
- %float1.000000e00_4637 = torch.constant.float 1.000000e+00
- %3917 = torch.aten.mul.Scalar %3916, %float1.000000e00_4637 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
- %int1_4638 = torch.constant.int 1
- %3918 = torch.aten.unsqueeze %3909, %int1_4638 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
- %int0_4639 = torch.constant.int 0
- %3919 = torch.aten.unsqueeze %3917, %int0_4639 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
- %3920 = torch.aten.mul.Tensor %3918, %3919 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
- %int1_4640 = torch.constant.int 1
- %3921 = torch.aten.size.int %3867, %int1_4640 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int
- %int0_4641 = torch.constant.int 0
- %3922 = torch.aten.add.int %int0_4641, %3921 : !torch.int, !torch.int -> !torch.int
- %int0_4642 = torch.constant.int 0
- %int0_4643 = torch.constant.int 0
- %int1_4644 = torch.constant.int 1
- %3923 = torch.aten.slice.Tensor %3920, %int0_4642, %int0_4643, %3922, %int1_4644 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %3923, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
- %int1_4645 = torch.constant.int 1
- %int0_4646 = torch.constant.int 0
- %int9223372036854775807_4647 = torch.constant.int 9223372036854775807
- %int1_4648 = torch.constant.int 1
- %3924 = torch.aten.slice.Tensor %3923, %int1_4645, %int0_4646, %int9223372036854775807_4647, %int1_4648 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %3924, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
- %int1_4649 = torch.constant.int 1
- %int0_4650 = torch.constant.int 0
- %int9223372036854775807_4651 = torch.constant.int 9223372036854775807
+ %int1_4619 = torch.constant.int 1
+ %int1_4620 = torch.constant.int 1
+ %3908 = torch.prim.ListConstruct %int4_4617, %int1_4618, %int1_4619, %int1_4620 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
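// Rotary application in rotate-half form: q_rot = q * cos + rotate_half(q) * sin,
// where rotate_half concatenates (-q[..., 64:], q[..., :64]) along the last dim.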
+ %3909 = torch.aten.repeat %3907, %3908 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %3909, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+ %3910 = torch.aten.mul.Tensor %3850, %3903 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %3910, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int3_4621 = torch.constant.int 3
+ %int0_4622 = torch.constant.int 0
+ %int64_4623 = torch.constant.int 64
+ %int1_4624 = torch.constant.int 1
+ %3911 = torch.aten.slice.Tensor %3850, %int3_4621, %int0_4622, %int64_4623, %int1_4624 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16>
+ torch.bind_symbolic_shape %3911, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16>
+ %int3_4625 = torch.constant.int 3
+ %int64_4626 = torch.constant.int 64
+ %int9223372036854775807_4627 = torch.constant.int 9223372036854775807
+ %int1_4628 = torch.constant.int 1
+ %3912 = torch.aten.slice.Tensor %3850, %int3_4625, %int64_4626, %int9223372036854775807_4627, %int1_4628 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16>
+ torch.bind_symbolic_shape %3912, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16>
+ %3913 = torch.aten.neg %3912 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16>
+ torch.bind_symbolic_shape %3913, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16>
+ %3914 = torch.prim.ListConstruct %3913, %3911 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list<vtensor>
+ %int-1_4629 = torch.constant.int -1
+ %3915 = torch.aten.cat %3914, %int-1_4629 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %3915, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %3916 = torch.aten.mul.Tensor %3915, %3909 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %3916, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int1_4630 = torch.constant.int 1
+ %3917 = torch.aten.add.Tensor %3910, %3916, %int1_4630 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %3917, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int131072_4631 = torch.constant.int 131072
+ %none_4632 = torch.constant.none
+ %none_4633 = torch.constant.none
+ %cpu_4634 = torch.constant.device "cpu"
+ %false_4635 = torch.constant.bool false
+ %3918 = torch.aten.arange %int131072_4631, %none_4632, %none_4633, %cpu_4634, %false_4635 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
+ %int0_4636 = torch.constant.int 0
+ %int128_4637 = torch.constant.int 128
+ %int2_4638 = torch.constant.int 2
+ %int4_4639 = torch.constant.int 4
+ %none_4640 = torch.constant.none
+ %cpu_4641 = torch.constant.device "cpu"
+ %false_4642 = torch.constant.bool false
+ %3919 = torch.aten.arange.start_step %int0_4636, %int128_4637, %int2_4638, %int4_4639, %none_4640, %cpu_4641, %false_4642 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64>
+ %int6_4643 = torch.constant.int 6
+ %3920 = torch.prims.convert_element_type %3919, %int6_4643 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32>
+ %int128_4644 = torch.constant.int 128
+ %3921 = torch.aten.div.Scalar %3920, %int128_4644 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float5.000000e05_4645 = torch.constant.float 5.000000e+05
+ %3922 = torch.aten.pow.Scalar %float5.000000e05_4645, %3921 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %3923 = torch.aten.reciprocal %3922 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float1.000000e00_4646 = torch.constant.float 1.000000e+00
+ %3924 = torch.aten.mul.Scalar %3923, %float1.000000e00_4646 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %3925 = torch.aten.reciprocal %3924 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float6.283190e00_4647 = torch.constant.float 6.2831853071795862
+ %3926 = torch.aten.mul.Scalar %3925, %float6.283190e00_4647 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %float8.192000e03_4648 = torch.constant.float 8.192000e+03
+ %3927 = torch.aten.gt.Scalar %3926, %float8.192000e03_4648 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %int8_4649 = torch.constant.int 8
+ %3928 = torch.aten.div.Scalar %3924, %int8_4649 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %3929 = torch.aten.where.self %3927, %3928, %3924 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %3930 = torch.aten.reciprocal %3926 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8192_4650 = torch.constant.int 8192
+ %3931 = torch.aten.mul.Scalar %3930, %int8192_4650 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %int1_4651 = torch.constant.int 1
  %int1_4652 = torch.constant.int 1
- %3925 = torch.aten.slice.Tensor %3924, %int1_4649, %int0_4650, %int9223372036854775807_4651, %int1_4652 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %3925, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
- %int0_4653 = torch.constant.int 0
- %3926 = torch.aten.unsqueeze %3925, %int0_4653 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %3926, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
+ %3932 = torch.aten.sub.Scalar %3931, %int1_4651, %int1_4652 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %int3_4653 = torch.constant.int 3
+ %3933 = torch.aten.div.Scalar %3932, %int3_4653 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
  %int1_4654 = torch.constant.int 1
- %int0_4655 = torch.constant.int 0
- %int9223372036854775807_4656 = torch.constant.int 9223372036854775807
+ %int1_4655 = torch.constant.int 1
+ %3934 = torch.aten.rsub.Scalar %3933, %int1_4654, %int1_4655 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %3935 = torch.aten.mul.Tensor %3934, %3929 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8_4656 = torch.constant.int 8
+ %3936 = torch.aten.div.Scalar %3935, %int8_4656 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %3937 = torch.aten.mul.Tensor %3933, %3929 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
  %int1_4657 = torch.constant.int 1
- %3927 = torch.aten.slice.Tensor %3926, %int1_4654, %int0_4655, %int9223372036854775807_4656, %int1_4657 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %3927, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
- %int2_4658 = torch.constant.int 2
- %int0_4659 = torch.constant.int 0
- %int9223372036854775807_4660 = torch.constant.int 9223372036854775807
- %int1_4661 = torch.constant.int 1
- %3928 = torch.aten.slice.Tensor %3927, %int2_4658, %int0_4659, %int9223372036854775807_4660, %int1_4661 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %3928, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
- %int4_4662 = torch.constant.int 4
- %int1_4663 = torch.constant.int 1
- %int1_4664 = torch.constant.int 1
- %3929 = torch.prim.ListConstruct %int4_4662, %int1_4663, %int1_4664 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %3930 = torch.aten.repeat %3928, %3929 : !torch.vtensor<[1,?,128],f32>, !torch.list<int> -> !torch.vtensor<[4,?,128],f32>
- torch.bind_symbolic_shape %3930, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32>
+ %3938 = torch.aten.add.Tensor %3936, %3937, %int1_4657 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float2.048000e03_4658 = torch.constant.float 2.048000e+03
+ %3939 = torch.aten.lt.Scalar %3926, %float2.048000e03_4658 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %3940 = torch.aten.bitwise_not %3939 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %float8.192000e03_4659 = torch.constant.float 8.192000e+03
+ %3941 = torch.aten.gt.Scalar %3926, %float8.192000e03_4659 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %3942 = torch.aten.bitwise_not %3941 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %3943 = torch.aten.mul.Tensor %3940, %3942 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %3944 = torch.aten.where.self %3943, %3938, %3929 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %3945 = torch.prim.ListConstruct %3944, %3944 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list<vtensor>
+ %int-1_4660 = torch.constant.int -1
+ %3946 = torch.aten.cat %3945, %int-1_4660 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[128],f32>
+ %int6_4661 = torch.constant.int 6
+ %3947 = torch.prims.convert_element_type %3946, %int6_4661 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
+ %int1_4662 = torch.constant.int 1
+ %3948 = torch.aten.unsqueeze %3918, %int1_4662 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
+ %int6_4663 = torch.constant.int 6
+ %3949 = torch.prims.convert_element_type %3948, %int6_4663 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32>
+ %int0_4664 = torch.constant.int 0
+ %3950 = torch.aten.unsqueeze %3947, %int0_4664 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
  %int6_4665 = torch.constant.int 6
- %3931 = torch.prims.convert_element_type %3878, %int6_4665 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32>
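// Old path (removed below): the K rotary embedding was dispatched to the
// sharktank_rotary_embedding custom kernel in f32; the new path recomputes
// the scaled cos/sin tables inline and applies them in f16.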
!torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %3931, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %3932 = torch_c.to_builtin_tensor %3931 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %3933 = torch_c.to_builtin_tensor %3930 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %3934 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%3932, %3933) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %3935 = torch_c.from_builtin_tensor %3934 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %3935, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> + %3951 = torch.prims.convert_element_type %3950, %int6_4665 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %3952 = torch.aten.mul.Tensor %3949, %3951 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %3953 = torch.aten.cos %3952 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> %int5_4666 = torch.constant.int 5 - %3936 = torch.prims.convert_element_type %3935, %int5_4666 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3936, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_4667 = torch.constant.int 64 - %3937 = torch.aten.mul.Scalar %arg2, %int64_4667 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3937, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int36 = torch.constant.int 36 - %int1_4668 = torch.constant.int 1 - %3938 = torch.aten.add.Scalar %3937, %int36, %int1_4668 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3938, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_4669 = torch.constant.int 4 - %int32_4670 = torch.constant.int 32 - %int8_4671 = torch.constant.int 8 - %int128_4672 = torch.constant.int 128 - %3939 = torch.prim.ListConstruct %int4_4669, %398, %int32_4670, %int8_4671, %int128_4672 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3940 = torch.aten.view %3936, %3939 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3940, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_4673 = torch.constant.int 4 - %3941 = torch.aten.mul.int %int4_4673, %398 : !torch.int, !torch.int -> !torch.int - %int32_4674 = torch.constant.int 32 - %int8_4675 = torch.constant.int 8 - %int128_4676 = torch.constant.int 128 - %3942 = torch.prim.ListConstruct %3941, %int32_4674, %int8_4675, %int128_4676 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3943 = torch.aten.view %3940, %3942 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3943, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_4677 = torch.constant.int 4 - %3944 = torch.aten.mul.int %int4_4677, %398 : !torch.int, !torch.int -> !torch.int - %3945 = torch.prim.ListConstruct %3944 : (!torch.int) -> !torch.list - %3946 = torch.aten.view %3938, %3945 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3946, [%292], 
affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_4678 = torch.constant.int 32 - %int2_4679 = torch.constant.int 2 - %int32_4680 = torch.constant.int 32 - %int8_4681 = torch.constant.int 8 - %int128_4682 = torch.constant.int 128 - %3947 = torch.prim.ListConstruct %389, %int32_4678, %int2_4679, %int32_4680, %int8_4681, %int128_4682 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3948 = torch.aten.view %3780, %3947 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3948, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4683 = torch.constant.int 32 - %3949 = torch.aten.mul.int %389, %int32_4683 : !torch.int, !torch.int -> !torch.int - %int2_4684 = torch.constant.int 2 - %3950 = torch.aten.mul.int %3949, %int2_4684 : !torch.int, !torch.int -> !torch.int - %int32_4685 = torch.constant.int 32 - %int8_4686 = torch.constant.int 8 - %int128_4687 = torch.constant.int 128 - %3951 = torch.prim.ListConstruct %3950, %int32_4685, %int8_4686, %int128_4687 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3952 = torch.aten.view %3948, %3951 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3952, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %3953 = torch.prim.ListConstruct %3946 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_4688 = torch.constant.bool false - %3954 = torch.aten.index_put %3952, %3953, %3943, %false_4688 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3954, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_4689 = torch.constant.int 32 - %int2_4690 = torch.constant.int 2 - %int32_4691 = torch.constant.int 32 - %int8_4692 = torch.constant.int 8 - %int128_4693 = torch.constant.int 128 - %3955 = torch.prim.ListConstruct %389, %int32_4689, %int2_4690, %int32_4691, %int8_4692, %int128_4693 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3956 = torch.aten.view %3954, %3955 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3956, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_4694 = torch.constant.int 2097152 - %3957 = torch.prim.ListConstruct %389, %int2097152_4694 : (!torch.int, !torch.int) -> !torch.list - %3958 = torch.aten.view %3956, %3957 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3958, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_4695 = torch.constant.int 32 - %int2_4696 = torch.constant.int 2 - %int32_4697 = torch.constant.int 32 - %int8_4698 = torch.constant.int 8 - %int128_4699 = torch.constant.int 128 - %3959 = torch.prim.ListConstruct %389, %int32_4695, %int2_4696, %int32_4697, %int8_4698, %int128_4699 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3960 = torch.aten.view %3958, %3959 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3960, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : 
!torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4700 = torch.constant.int 32 - %int8_4701 = torch.constant.int 8 - %int128_4702 = torch.constant.int 128 - %3961 = torch.prim.ListConstruct %3950, %int32_4700, %int8_4701, %int128_4702 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3962 = torch.aten.view %3960, %3961 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3962, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_4703 = torch.constant.int 4 - %int32_4704 = torch.constant.int 32 - %int8_4705 = torch.constant.int 8 - %int128_4706 = torch.constant.int 128 - %3963 = torch.prim.ListConstruct %int4_4703, %398, %int32_4704, %int8_4705, %int128_4706 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3964 = torch.aten.view %3880, %3963 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3964, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_4707 = torch.constant.int 4 - %3965 = torch.aten.mul.int %int4_4707, %398 : !torch.int, !torch.int -> !torch.int - %int32_4708 = torch.constant.int 32 - %int8_4709 = torch.constant.int 8 - %int128_4710 = torch.constant.int 128 - %3966 = torch.prim.ListConstruct %3965, %int32_4708, %int8_4709, %int128_4710 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3967 = torch.aten.view %3964, %3966 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3967, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_4711 = torch.constant.int 1 - %int1_4712 = torch.constant.int 1 - %3968 = torch.aten.add.Scalar %3938, %int1_4711, %int1_4712 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3968, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_4713 = torch.constant.int 4 - %3969 = torch.aten.mul.int %int4_4713, %398 : !torch.int, !torch.int -> !torch.int - %3970 = torch.prim.ListConstruct %3969 : (!torch.int) -> !torch.list - %3971 = torch.aten.view %3968, %3970 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3971, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %3972 = torch.prim.ListConstruct %3971 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_4714 = torch.constant.bool false - %3973 = torch.aten.index_put %3962, %3972, %3967, %false_4714 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %3973, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_4715 = torch.constant.int 32 - %int2_4716 = torch.constant.int 2 - %int32_4717 = torch.constant.int 32 - %int8_4718 = torch.constant.int 8 - %int128_4719 = torch.constant.int 128 - %3974 = torch.prim.ListConstruct %389, %int32_4715, %int2_4716, %int32_4717, %int8_4718, %int128_4719 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3975 = torch.aten.view %3973, %3974 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3975, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - 
%int2097152_4720 = torch.constant.int 2097152 - %3976 = torch.prim.ListConstruct %389, %int2097152_4720 : (!torch.int, !torch.int) -> !torch.list - %3977 = torch.aten.view %3975, %3976 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3977, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_4721 = torch.constant.int -2 - %3978 = torch.aten.unsqueeze %3936, %int-2_4721 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3978, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_4722 = torch.constant.int 4 - %int8_4723 = torch.constant.int 8 - %int4_4724 = torch.constant.int 4 - %int128_4725 = torch.constant.int 128 - %3979 = torch.prim.ListConstruct %int4_4722, %3921, %int8_4723, %int4_4724, %int128_4725 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_4726 = torch.constant.bool false - %3980 = torch.aten.expand %3978, %3979, %false_4726 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3980, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_4727 = torch.constant.int 0 - %3981 = torch.aten.clone %3980, %int0_4727 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3981, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_4728 = torch.constant.int 4 + %3954 = torch.prims.convert_element_type %3953, %int5_4666 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %3955 = torch.aten.sin %3952 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_4667 = torch.constant.int 5 + %3956 = torch.prims.convert_element_type %3955, %int5_4667 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_4668 = torch.constant.int 0 + %int0_4669 = torch.constant.int 0 + %int1_4670 = torch.constant.int 1 + %3957 = torch.aten.slice.Tensor %3954, %int0_4668, %int0_4669, %298, %int1_4670 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3957, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_4671 = torch.constant.int 1 + %int0_4672 = torch.constant.int 0 + %int9223372036854775807_4673 = torch.constant.int 9223372036854775807 + %int1_4674 = torch.constant.int 1 + %3958 = torch.aten.slice.Tensor %3957, %int1_4671, %int0_4672, %int9223372036854775807_4673, %int1_4674 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3958, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_4675 = torch.constant.int 0 + %int0_4676 = torch.constant.int 0 + %int1_4677 = torch.constant.int 1 + %3959 = torch.aten.slice.Tensor %3956, %int0_4675, %int0_4676, %298, %int1_4677 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3959, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_4678 = torch.constant.int 1 + %int0_4679 = torch.constant.int 0 + %int9223372036854775807_4680 = torch.constant.int 9223372036854775807 + 
%int1_4681 = torch.constant.int 1 + %3960 = torch.aten.slice.Tensor %3959, %int1_4678, %int0_4679, %int9223372036854775807_4680, %int1_4681 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3960, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_4682 = torch.constant.int 0 + %3961 = torch.aten.unsqueeze %3958, %int0_4682 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3961, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_4683 = torch.constant.int 1 + %int0_4684 = torch.constant.int 0 + %int9223372036854775807_4685 = torch.constant.int 9223372036854775807 + %int1_4686 = torch.constant.int 1 + %3962 = torch.aten.slice.Tensor %3961, %int1_4683, %int0_4684, %int9223372036854775807_4685, %int1_4686 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3962, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_4687 = torch.constant.int 2 + %3963 = torch.aten.unsqueeze %3962, %int2_4687 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3963, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_4688 = torch.constant.int 3 + %int0_4689 = torch.constant.int 0 + %int9223372036854775807_4690 = torch.constant.int 9223372036854775807 + %int1_4691 = torch.constant.int 1 + %3964 = torch.aten.slice.Tensor %3963, %int3_4688, %int0_4689, %int9223372036854775807_4690, %int1_4691 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3964, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_4692 = torch.constant.int 4 + %int1_4693 = torch.constant.int 1 + %int1_4694 = torch.constant.int 1 + %int1_4695 = torch.constant.int 1 + %3965 = torch.prim.ListConstruct %int4_4692, %int1_4693, %int1_4694, %int1_4695 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3966 = torch.aten.repeat %3964, %3965 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3966, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_4696 = torch.constant.int 0 + %3967 = torch.aten.unsqueeze %3960, %int0_4696 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3967, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_4697 = torch.constant.int 1 + %int0_4698 = torch.constant.int 0 + %int9223372036854775807_4699 = torch.constant.int 9223372036854775807 + %int1_4700 = torch.constant.int 1 + %3968 = torch.aten.slice.Tensor %3967, %int1_4697, %int0_4698, %int9223372036854775807_4699, %int1_4700 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %3968, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_4701 = torch.constant.int 2 + %3969 = torch.aten.unsqueeze %3968, %int2_4701 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3969, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : 
!torch.vtensor<[1,?,1,128],f16> + %int3_4702 = torch.constant.int 3 + %int0_4703 = torch.constant.int 0 + %int9223372036854775807_4704 = torch.constant.int 9223372036854775807 + %int1_4705 = torch.constant.int 1 + %3970 = torch.aten.slice.Tensor %3969, %int3_4702, %int0_4703, %int9223372036854775807_4704, %int1_4705 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %3970, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_4706 = torch.constant.int 4 + %int1_4707 = torch.constant.int 1 + %int1_4708 = torch.constant.int 1 + %int1_4709 = torch.constant.int 1 + %3971 = torch.prim.ListConstruct %int4_4706, %int1_4707, %int1_4708, %int1_4709 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3972 = torch.aten.repeat %3970, %3971 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %3972, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %3973 = torch.aten.mul.Tensor %3852, %3966 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3973, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_4710 = torch.constant.int 3 + %int0_4711 = torch.constant.int 0 + %int64_4712 = torch.constant.int 64 + %int1_4713 = torch.constant.int 1 + %3974 = torch.aten.slice.Tensor %3852, %int3_4710, %int0_4711, %int64_4712, %int1_4713 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %3974, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_4714 = torch.constant.int 3 + %int64_4715 = torch.constant.int 64 + %int9223372036854775807_4716 = torch.constant.int 9223372036854775807 + %int1_4717 = torch.constant.int 1 + %3975 = torch.aten.slice.Tensor %3852, %int3_4714, %int64_4715, %int9223372036854775807_4716, %int1_4717 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %3975, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %3976 = torch.aten.neg %3975 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %3976, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %3977 = torch.prim.ListConstruct %3976, %3974 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_4718 = torch.constant.int -1 + %3978 = torch.aten.cat %3977, %int-1_4718 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3978, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %3979 = torch.aten.mul.Tensor %3978, %3972 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3979, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_4719 = torch.constant.int 1 + %3980 = torch.aten.add.Tensor %3973, %3979, %int1_4719 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3980, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + 
%int32_4720 = torch.constant.int 32 + %3981 = torch.aten.mul.Scalar %arg2, %int32_4720 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3981, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int13 = torch.constant.int 13 + %int1_4721 = torch.constant.int 1 + %3982 = torch.aten.add.Scalar %3981, %int13, %int1_4721 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3982, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_4722 = torch.constant.int 2 + %3983 = torch.aten.mul.Scalar %3982, %int2_4722 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3983, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_4723 = torch.constant.int 0 + %int1_4724 = torch.constant.int 1 + %3984 = torch.aten.add.Scalar %3983, %int0_4723, %int1_4724 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %3984, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %3985 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %3986 = torch.aten.view %3984, %3985 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %3986, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_4725 = torch.constant.int 4 + %int32_4726 = torch.constant.int 32 + %int8_4727 = torch.constant.int 8 + %int128_4728 = torch.constant.int 128 + %3987 = torch.prim.ListConstruct %int4_4725, %296, %int32_4726, %int8_4727, %int128_4728 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3988 = torch.aten.view %3980, %3987 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3988, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int32_4729 = torch.constant.int 32 - %int128_4730 = torch.constant.int 128 - %3982 = torch.prim.ListConstruct %int4_4728, %3921, %int32_4729, %int128_4730 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3983 = torch.aten._unsafe_view %3981, %3982 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3983, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_4731 = torch.constant.int -2 - %3984 = torch.aten.unsqueeze %3880, %int-2_4731 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3984, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int8_4730 = torch.constant.int 8 + %int128_4731 = torch.constant.int 128 + %3989 = torch.prim.ListConstruct %504, %int32_4729, %int8_4730, %int128_4731 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3990 = torch.aten.view %3988, %3989 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %3990, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> %int1_4732 = torch.constant.int 1 - %3985 = torch.aten.size.int %3874, %int1_4732 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_4733 = torch.constant.int 4 - %int8_4734 = torch.constant.int 8 - %int4_4735 = torch.constant.int 4 - %int128_4736 = torch.constant.int 128 - %3986 = 
torch.prim.ListConstruct %int4_4733, %3985, %int8_4734, %int4_4735, %int128_4736 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_4737 = torch.constant.bool false - %3987 = torch.aten.expand %3984, %3986, %false_4737 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3987, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_4738 = torch.constant.int 0 - %3988 = torch.aten.clone %3987, %int0_4738 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3988, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_4739 = torch.constant.int 4 - %int32_4740 = torch.constant.int 32 - %int128_4741 = torch.constant.int 128 - %3989 = torch.prim.ListConstruct %int4_4739, %3985, %int32_4740, %int128_4741 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3990 = torch.aten._unsafe_view %3988, %3989 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3990, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_4742 = torch.constant.int 1 - %int2_4743 = torch.constant.int 2 - %3991 = torch.aten.transpose.int %3908, %int1_4742, %int2_4743 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3991, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_4744 = torch.constant.int 1 + %int2_4733 = torch.constant.int 2 + %3991 = torch.aten.transpose.int %3990, %int1_4732, %int2_4733 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3991, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_4734 = torch.constant.int 5 + %3992 = torch.prims.convert_element_type %3991, %int5_4734 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3992, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_4735 = torch.constant.int 32 + %int2_4736 = torch.constant.int 2 + %int8_4737 = torch.constant.int 8 + %int32_4738 = torch.constant.int 32 + %int128_4739 = torch.constant.int 128 + %3993 = torch.prim.ListConstruct %297, %int32_4735, %int2_4736, %int8_4737, %int32_4738, %int128_4739 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3994 = torch.aten.view %3756, %3993 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3994, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_4740 = torch.constant.int 8 + %int32_4741 = torch.constant.int 32 + %int128_4742 = torch.constant.int 128 + %3995 = torch.prim.ListConstruct %497, %int8_4740, %int32_4741, %int128_4742 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3996 = torch.aten.view %3994, %3995 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3996, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %3997 = torch.prim.ListConstruct %3986 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_4743 = 
torch.constant.bool false + %3998 = torch.aten.index_put %3996, %3997, %3992, %false_4743 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %3998, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_4744 = torch.constant.int 32 %int2_4745 = torch.constant.int 2 - %3992 = torch.aten.transpose.int %3983, %int1_4744, %int2_4745 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3992, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_4746 = torch.constant.int 1 - %int2_4747 = torch.constant.int 2 - %3993 = torch.aten.transpose.int %3990, %int1_4746, %int2_4747 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3993, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_4748 = torch.constant.float 0.000000e+00 - %true_4749 = torch.constant.bool true - %none_4750 = torch.constant.none - %none_4751 = torch.constant.none - %3994:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3991, %3992, %3993, %float0.000000e00_4748, %true_4749, %none_4750, %none_4751) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %3994#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_4752 = torch.constant.int 1 - %int2_4753 = torch.constant.int 2 - %3995 = torch.aten.transpose.int %3994#0, %int1_4752, %int2_4753 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3995, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_4754 = torch.constant.int 4 - %int4096_4755 = torch.constant.int 4096 - %3996 = torch.prim.ListConstruct %int4_4754, %3893, %int4096_4755 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3997 = torch.aten.view %3995, %3996 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %3997, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_4756 = torch.constant.int -2 - %int-1_4757 = torch.constant.int -1 - %3998 = torch.aten.transpose.int %167, %int-2_4756, %int-1_4757 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_4758 = torch.constant.int 4 - %3999 = torch.aten.mul.int %int4_4758, %3893 : !torch.int, !torch.int -> !torch.int - %int4096_4759 = torch.constant.int 4096 - %4000 = torch.prim.ListConstruct %3999, %int4096_4759 : (!torch.int, !torch.int) -> !torch.list - %4001 = torch.aten.view %3997, %4000 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4001, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4002 = torch.aten.mm %4001, %3998 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4002, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_4760 = torch.constant.int 4 - 
%int4096_4761 = torch.constant.int 4096 - %4003 = torch.prim.ListConstruct %int4_4760, %3893, %int4096_4761 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4004 = torch.aten.view %4002, %4003 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4004, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int8_4746 = torch.constant.int 8 + %int32_4747 = torch.constant.int 32 + %int128_4748 = torch.constant.int 128 + %3999 = torch.prim.ListConstruct %297, %int32_4744, %int2_4745, %int8_4746, %int32_4747, %int128_4748 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4000 = torch.aten.view %3998, %3999 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4000, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_4749 = torch.constant.int 2097152 + %4001 = torch.prim.ListConstruct %297, %int2097152_4749 : (!torch.int, !torch.int) -> !torch.list + %4002 = torch.aten.view %4000, %4001 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %4002, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_4750 = torch.constant.int 32 + %int2_4751 = torch.constant.int 2 + %int8_4752 = torch.constant.int 8 + %int32_4753 = torch.constant.int 32 + %int128_4754 = torch.constant.int 128 + %4003 = torch.prim.ListConstruct %297, %int32_4750, %int2_4751, %int8_4752, %int32_4753, %int128_4754 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4004 = torch.aten.view %4002, %4003 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4004, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_4755 = torch.constant.int 8 + %int32_4756 = torch.constant.int 32 + %int128_4757 = torch.constant.int 128 + %4005 = torch.prim.ListConstruct %497, %int8_4755, %int32_4756, %int128_4757 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4006 = torch.aten.view %4004, %4005 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4006, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_4758 = torch.constant.int 32 + %4007 = torch.aten.mul.Scalar %arg2, %int32_4758 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4007, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int13_4759 = torch.constant.int 13 + %int1_4760 = torch.constant.int 1 + %4008 = torch.aten.add.Scalar %4007, %int13_4759, %int1_4760 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4008, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_4761 = torch.constant.int 2 + %4009 = torch.aten.mul.Scalar %4008, %int2_4761 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4009, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> %int1_4762 = torch.constant.int 1 - %4005 = torch.aten.add.Tensor %3843, %4004, %int1_4762 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - 
torch.bind_symbolic_shape %4005, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_4763 = torch.constant.int 6 - %4006 = torch.prims.convert_element_type %4005, %int6_4763 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4006, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_4764 = torch.constant.int 2 - %4007 = torch.aten.pow.Tensor_Scalar %4006, %int2_4764 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4007, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_4765 = torch.constant.int -1 - %4008 = torch.prim.ListConstruct %int-1_4765 : (!torch.int) -> !torch.list - %true_4766 = torch.constant.bool true - %none_4767 = torch.constant.none - %4009 = torch.aten.mean.dim %4007, %4008, %true_4766, %none_4767 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4009, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_4768 = torch.constant.float 9.9999997473787516E-6 - %int1_4769 = torch.constant.int 1 - %4010 = torch.aten.add.Scalar %4009, %float9.999990e-06_4768, %int1_4769 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4010, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4011 = torch.aten.rsqrt %4010 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4011, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4012 = torch.aten.mul.Tensor %4006, %4011 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4012, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_4770 = torch.constant.int 5 - %4013 = torch.prims.convert_element_type %4012, %int5_4770 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4013, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %4014 = torch.aten.mul.Tensor %168, %4013 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4014, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_4771 = torch.constant.int 5 - %4015 = torch.prims.convert_element_type %4014, %int5_4771 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4015, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_4772 = torch.constant.int -2 - %int-1_4773 = torch.constant.int -1 - %4016 = torch.aten.transpose.int %169, %int-2_4772, %int-1_4773 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_4774 = torch.constant.int 4 - %4017 = torch.aten.mul.int %int4_4774, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4775 = torch.constant.int 4096 - %4018 = torch.prim.ListConstruct %4017, %int4096_4775 : (!torch.int, !torch.int) -> !torch.list - %4019 = torch.aten.view %4015, %4018 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4019, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : 
!torch.vtensor<[?,4096],f16> - %4020 = torch.aten.mm %4019, %4016 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %4020, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_4776 = torch.constant.int 4 - %int14336_4777 = torch.constant.int 14336 - %4021 = torch.prim.ListConstruct %int4_4776, %306, %int14336_4777 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4022 = torch.aten.view %4020, %4021 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %4022, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %4023 = torch.aten.silu %4022 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %4023, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_4778 = torch.constant.int -2 - %int-1_4779 = torch.constant.int -1 - %4024 = torch.aten.transpose.int %170, %int-2_4778, %int-1_4779 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_4780 = torch.constant.int 4 - %4025 = torch.aten.mul.int %int4_4780, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4781 = torch.constant.int 4096 - %4026 = torch.prim.ListConstruct %4025, %int4096_4781 : (!torch.int, !torch.int) -> !torch.list - %4027 = torch.aten.view %4015, %4026 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4027, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4028 = torch.aten.mm %4027, %4024 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %4028, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int1_4763 = torch.constant.int 1 + %4010 = torch.aten.add.Scalar %4009, %int1_4762, %int1_4763 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4010, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %4011 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %4012 = torch.aten.view %4010, %4011 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %4012, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_4764 = torch.constant.int 4 + %int32_4765 = torch.constant.int 32 + %int8_4766 = torch.constant.int 8 + %int128_4767 = torch.constant.int 128 + %4013 = torch.prim.ListConstruct %int4_4764, %296, %int32_4765, %int8_4766, %int128_4767 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4014 = torch.aten.view %3854, %4013 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4014, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_4768 = torch.constant.int 32 + %int8_4769 = torch.constant.int 8 + %int128_4770 = torch.constant.int 128 + %4015 = torch.prim.ListConstruct %504, %int32_4768, %int8_4769, %int128_4770 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4016 = torch.aten.view %4014, %4015 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %4016, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + 
%int1_4771 = torch.constant.int 1 + %int2_4772 = torch.constant.int 2 + %4017 = torch.aten.transpose.int %4016, %int1_4771, %int2_4772 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4017, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_4773 = torch.constant.int 5 + %4018 = torch.prims.convert_element_type %4017, %int5_4773 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4018, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %4019 = torch.prim.ListConstruct %4012 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_4774 = torch.constant.bool false + %4020 = torch.aten.index_put %4006, %4019, %4018, %false_4774 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4020, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_4775 = torch.constant.int 32 + %int2_4776 = torch.constant.int 2 + %int8_4777 = torch.constant.int 8 + %int32_4778 = torch.constant.int 32 + %int128_4779 = torch.constant.int 128 + %4021 = torch.prim.ListConstruct %297, %int32_4775, %int2_4776, %int8_4777, %int32_4778, %int128_4779 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4022 = torch.aten.view %4020, %4021 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4022, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_4780 = torch.constant.int 2097152 + %4023 = torch.prim.ListConstruct %297, %int2097152_4780 : (!torch.int, !torch.int) -> !torch.list + %4024 = torch.aten.view %4022, %4023 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %4024, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_4781 = torch.constant.int -2 + %4025 = torch.aten.unsqueeze %3980, %int-2_4781 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %4025, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> %int4_4782 = torch.constant.int 4 - %int14336_4783 = torch.constant.int 14336 - %4029 = torch.prim.ListConstruct %int4_4782, %306, %int14336_4783 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4030 = torch.aten.view %4028, %4029 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %4030, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %4031 = torch.aten.mul.Tensor %4023, %4030 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %4031, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_4784 = torch.constant.int -2 - %int-1_4785 = torch.constant.int -1 - %4032 = torch.aten.transpose.int %171, %int-2_4784, %int-1_4785 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_4786 = torch.constant.int 1 - %4033 = torch.aten.size.int %4022, %int1_4786 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_4787 = torch.constant.int 4 - %4034 = 
torch.aten.mul.int %int4_4787, %4033 : !torch.int, !torch.int -> !torch.int - %int14336_4788 = torch.constant.int 14336 - %4035 = torch.prim.ListConstruct %4034, %int14336_4788 : (!torch.int, !torch.int) -> !torch.list - %4036 = torch.aten.view %4031, %4035 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %4036, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %4037 = torch.aten.mm %4036, %4032 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4037, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_4789 = torch.constant.int 4 - %int4096_4790 = torch.constant.int 4096 - %4038 = torch.prim.ListConstruct %int4_4789, %4033, %int4096_4790 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4039 = torch.aten.view %4037, %4038 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4039, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_4791 = torch.constant.int 1 - %4040 = torch.aten.add.Tensor %4005, %4039, %int1_4791 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4040, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_4792 = torch.constant.int 6 - %4041 = torch.prims.convert_element_type %4040, %int6_4792 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4041, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_4793 = torch.constant.int 2 - %4042 = torch.aten.pow.Tensor_Scalar %4041, %int2_4793 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4042, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_4794 = torch.constant.int -1 - %4043 = torch.prim.ListConstruct %int-1_4794 : (!torch.int) -> !torch.list - %true_4795 = torch.constant.bool true - %none_4796 = torch.constant.none - %4044 = torch.aten.mean.dim %4042, %4043, %true_4795, %none_4796 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4044, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_4797 = torch.constant.float 9.9999997473787516E-6 - %int1_4798 = torch.constant.int 1 - %4045 = torch.aten.add.Scalar %4044, %float9.999990e-06_4797, %int1_4798 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4045, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4046 = torch.aten.rsqrt %4045 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4046, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4047 = torch.aten.mul.Tensor %4041, %4046 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4047, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_4799 = torch.constant.int 5 - %4048 = torch.prims.convert_element_type %4047, %int5_4799 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - 
torch.bind_symbolic_shape %4048, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %4049 = torch.aten.mul.Tensor %172, %4048 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4049, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_4800 = torch.constant.int 5 - %4050 = torch.prims.convert_element_type %4049, %int5_4800 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4050, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_4801 = torch.constant.int -2 - %int-1_4802 = torch.constant.int -1 - %4051 = torch.aten.transpose.int %173, %int-2_4801, %int-1_4802 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_4803 = torch.constant.int 4 - %4052 = torch.aten.mul.int %int4_4803, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4804 = torch.constant.int 4096 - %4053 = torch.prim.ListConstruct %4052, %int4096_4804 : (!torch.int, !torch.int) -> !torch.list - %4054 = torch.aten.view %4050, %4053 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4054, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4055 = torch.aten.mm %4054, %4051 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4055, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_4805 = torch.constant.int 4 - %int4096_4806 = torch.constant.int 4096 - %4056 = torch.prim.ListConstruct %int4_4805, %306, %int4096_4806 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4057 = torch.aten.view %4055, %4056 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4057, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_4807 = torch.constant.int -2 - %int-1_4808 = torch.constant.int -1 - %4058 = torch.aten.transpose.int %174, %int-2_4807, %int-1_4808 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_4809 = torch.constant.int 4 - %4059 = torch.aten.mul.int %int4_4809, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4810 = torch.constant.int 4096 - %4060 = torch.prim.ListConstruct %4059, %int4096_4810 : (!torch.int, !torch.int) -> !torch.list - %4061 = torch.aten.view %4050, %4060 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4061, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4062 = torch.aten.mm %4061, %4058 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %4062, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_4811 = torch.constant.int 4 - %int1024_4812 = torch.constant.int 1024 - %4063 = torch.prim.ListConstruct %int4_4811, %306, %int1024_4812 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4064 = torch.aten.view %4062, %4063 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %4064, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_4813 = torch.constant.int -2 - %int-1_4814 = torch.constant.int -1 - 
%4065 = torch.aten.transpose.int %175, %int-2_4813, %int-1_4814 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_4815 = torch.constant.int 4 - %4066 = torch.aten.mul.int %int4_4815, %306 : !torch.int, !torch.int -> !torch.int - %int4096_4816 = torch.constant.int 4096 - %4067 = torch.prim.ListConstruct %4066, %int4096_4816 : (!torch.int, !torch.int) -> !torch.list - %4068 = torch.aten.view %4050, %4067 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4068, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4069 = torch.aten.mm %4068, %4065 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %4069, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_4817 = torch.constant.int 4 - %int1024_4818 = torch.constant.int 1024 - %4070 = torch.prim.ListConstruct %int4_4817, %306, %int1024_4818 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4071 = torch.aten.view %4069, %4070 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %4071, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_4819 = torch.constant.int 4 - %int32_4820 = torch.constant.int 32 - %int128_4821 = torch.constant.int 128 - %4072 = torch.prim.ListConstruct %int4_4819, %306, %int32_4820, %int128_4821 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4073 = torch.aten.view %4057, %4072 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4073, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_4822 = torch.constant.int 4 - %int8_4823 = torch.constant.int 8 - %int128_4824 = torch.constant.int 128 - %4074 = torch.prim.ListConstruct %int4_4822, %306, %int8_4823, %int128_4824 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4075 = torch.aten.view %4064, %4074 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4075, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_4825 = torch.constant.int 4 - %int8_4826 = torch.constant.int 8 - %int128_4827 = torch.constant.int 128 - %4076 = torch.prim.ListConstruct %int4_4825, %306, %int8_4826, %int128_4827 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4077 = torch.aten.view %4071, %4076 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4077, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_4828 = torch.constant.int 131072 - %none_4829 = torch.constant.none - %none_4830 = torch.constant.none - %cpu_4831 = torch.constant.device "cpu" - %false_4832 = torch.constant.bool false - %4078 = torch.aten.arange %int131072_4828, %none_4829, %none_4830, %cpu_4831, %false_4832 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_4833 = torch.constant.int 0 - %int128_4834 = torch.constant.int 128 - %none_4835 = torch.constant.none - %none_4836 = torch.constant.none - %cpu_4837 = torch.constant.device "cpu" - %false_4838 = torch.constant.bool false - %4079 = torch.aten.arange.start %int0_4833, %int128_4834, %none_4835, %none_4836, %cpu_4837, %false_4838 : 
!torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_4839 = torch.constant.int 2 - %4080 = torch.aten.floor_divide.Scalar %4079, %int2_4839 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_4840 = torch.constant.int 6 - %4081 = torch.prims.convert_element_type %4080, %int6_4840 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_4841 = torch.constant.int 128 - %4082 = torch.aten.div.Scalar %4081, %int128_4841 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_4842 = torch.constant.float 2.000000e+00 - %4083 = torch.aten.mul.Scalar %4082, %float2.000000e00_4842 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_4843 = torch.constant.float 5.000000e+05 - %4084 = torch.aten.pow.Scalar %float5.000000e05_4843, %4083 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %4085 = torch.aten.reciprocal %4084 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_4844 = torch.constant.float 1.000000e+00 - %4086 = torch.aten.mul.Scalar %4085, %float1.000000e00_4844 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_4845 = torch.constant.int 1 - %4087 = torch.aten.unsqueeze %4078, %int1_4845 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_4846 = torch.constant.int 0 - %4088 = torch.aten.unsqueeze %4086, %int0_4846 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %4089 = torch.aten.mul.Tensor %4087, %4088 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_4847 = torch.constant.int 1 - %4090 = torch.aten.size.int %4057, %int1_4847 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_4848 = torch.constant.int 0 - %4091 = torch.aten.add.int %int0_4848, %4090 : !torch.int, !torch.int -> !torch.int - %int0_4849 = torch.constant.int 0 - %int0_4850 = torch.constant.int 0 - %int1_4851 = torch.constant.int 1 - %4092 = torch.aten.slice.Tensor %4089, %int0_4849, %int0_4850, %4091, %int1_4851 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4092, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_4852 = torch.constant.int 1 - %int0_4853 = torch.constant.int 0 - %int9223372036854775807_4854 = torch.constant.int 9223372036854775807 + %int8_4783 = torch.constant.int 8 + %int4_4784 = torch.constant.int 4 + %int128_4785 = torch.constant.int 128 + %4026 = torch.prim.ListConstruct %int4_4782, %298, %int8_4783, %int4_4784, %int128_4785 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_4786 = torch.constant.bool false + %4027 = torch.aten.expand %4025, %4026, %false_4786 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4027, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_4787 = torch.constant.int 0 + %4028 = torch.aten.clone %4027, %int0_4787 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4028, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_4788 = torch.constant.int 4 + %int32_4789 = torch.constant.int 32 + %int128_4790 = 
torch.constant.int 128 + %4029 = torch.prim.ListConstruct %int4_4788, %298, %int32_4789, %int128_4790 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4030 = torch.aten._unsafe_view %4028, %4029 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4030, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_4791 = torch.constant.int -2 + %4031 = torch.aten.unsqueeze %3854, %int-2_4791 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %4031, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_4792 = torch.constant.int 4 + %int8_4793 = torch.constant.int 8 + %int4_4794 = torch.constant.int 4 + %int128_4795 = torch.constant.int 128 + %4032 = torch.prim.ListConstruct %int4_4792, %298, %int8_4793, %int4_4794, %int128_4795 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_4796 = torch.constant.bool false + %4033 = torch.aten.expand %4031, %4032, %false_4796 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4033, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_4797 = torch.constant.int 0 + %4034 = torch.aten.clone %4033, %int0_4797 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4034, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_4798 = torch.constant.int 4 + %int32_4799 = torch.constant.int 32 + %int128_4800 = torch.constant.int 128 + %4035 = torch.prim.ListConstruct %int4_4798, %298, %int32_4799, %int128_4800 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4036 = torch.aten._unsafe_view %4034, %4035 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4036, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_4801 = torch.constant.int 1 + %int2_4802 = torch.constant.int 2 + %4037 = torch.aten.transpose.int %3917, %int1_4801, %int2_4802 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4037, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_4803 = torch.constant.int 1 + %int2_4804 = torch.constant.int 2 + %4038 = torch.aten.transpose.int %4030, %int1_4803, %int2_4804 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4038, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_4805 = torch.constant.int 1 + %int2_4806 = torch.constant.int 2 + %4039 = torch.aten.transpose.int %4036, %int1_4805, %int2_4806 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4039, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_4807 = torch.constant.float 0.000000e+00 + %false_4808 = torch.constant.bool false + %none_4809 = torch.constant.none + %4040:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4037, %4038, %4039, %float0.000000e00_4807, %false_4808, %327, %none_4809) : (!torch.vtensor<[4,32,?,128],f16>, 
!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %4040#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_4810 = torch.constant.int 1 + %int2_4811 = torch.constant.int 2 + %4041 = torch.aten.transpose.int %4040#0, %int1_4810, %int2_4811 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4041, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_4812 = torch.constant.int 4 + %int4096_4813 = torch.constant.int 4096 + %4042 = torch.prim.ListConstruct %int4_4812, %298, %int4096_4813 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4043 = torch.aten.view %4041, %4042 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4043, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_4814 = torch.constant.int -2 + %int-1_4815 = torch.constant.int -1 + %4044 = torch.aten.transpose.int %123, %int-2_4814, %int-1_4815 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_4816 = torch.constant.int 5 + %4045 = torch.prims.convert_element_type %4044, %int5_4816 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_4817 = torch.constant.int 4096 + %4046 = torch.prim.ListConstruct %342, %int4096_4817 : (!torch.int, !torch.int) -> !torch.list + %4047 = torch.aten.view %4043, %4046 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4047, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4048 = torch.aten.mm %4047, %4045 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4048, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_4818 = torch.constant.int 4 + %int4096_4819 = torch.constant.int 4096 + %4049 = torch.prim.ListConstruct %int4_4818, %298, %int4096_4819 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4050 = torch.aten.view %4048, %4049 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4050, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_4820 = torch.constant.int 1 + %4051 = torch.aten.add.Tensor %3817, %4050, %int1_4820 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4051, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_4821 = torch.constant.int 6 + %4052 = torch.prims.convert_element_type %4051, %int6_4821 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4052, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_4822 = torch.constant.int 2 + %4053 = torch.aten.pow.Tensor_Scalar %4052, %int2_4822 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4053, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_4823 = torch.constant.int -1 + %4054 = 
torch.prim.ListConstruct %int-1_4823 : (!torch.int) -> !torch.list + %true_4824 = torch.constant.bool true + %none_4825 = torch.constant.none + %4055 = torch.aten.mean.dim %4053, %4054, %true_4824, %none_4825 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4055, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_4826 = torch.constant.float 9.9999997473787516E-6 + %int1_4827 = torch.constant.int 1 + %4056 = torch.aten.add.Scalar %4055, %float9.999990e-06_4826, %int1_4827 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4056, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4057 = torch.aten.rsqrt %4056 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4057, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4058 = torch.aten.mul.Tensor %4052, %4057 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4058, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_4828 = torch.constant.int 5 + %4059 = torch.prims.convert_element_type %4058, %int5_4828 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4059, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %4060 = torch.aten.mul.Tensor %124, %4059 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4060, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_4829 = torch.constant.int 5 + %4061 = torch.prims.convert_element_type %4060, %int5_4829 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4061, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_4830 = torch.constant.int -2 + %int-1_4831 = torch.constant.int -1 + %4062 = torch.aten.transpose.int %125, %int-2_4830, %int-1_4831 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_4832 = torch.constant.int 5 + %4063 = torch.prims.convert_element_type %4062, %int5_4832 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_4833 = torch.constant.int 4096 + %4064 = torch.prim.ListConstruct %342, %int4096_4833 : (!torch.int, !torch.int) -> !torch.list + %4065 = torch.aten.view %4061, %4064 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4065, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4066 = torch.aten.mm %4065, %4063 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %4066, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_4834 = torch.constant.int 4 + %int14336_4835 = torch.constant.int 14336 + %4067 = torch.prim.ListConstruct %int4_4834, %298, %int14336_4835 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4068 = torch.aten.view %4066, %4067 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4068, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : 
!torch.vtensor<[4,?,14336],f16> + %4069 = torch.aten.silu %4068 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4069, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_4836 = torch.constant.int -2 + %int-1_4837 = torch.constant.int -1 + %4070 = torch.aten.transpose.int %126, %int-2_4836, %int-1_4837 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_4838 = torch.constant.int 5 + %4071 = torch.prims.convert_element_type %4070, %int5_4838 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_4839 = torch.constant.int 4096 + %4072 = torch.prim.ListConstruct %342, %int4096_4839 : (!torch.int, !torch.int) -> !torch.list + %4073 = torch.aten.view %4061, %4072 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4073, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4074 = torch.aten.mm %4073, %4071 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %4074, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_4840 = torch.constant.int 4 + %int14336_4841 = torch.constant.int 14336 + %4075 = torch.prim.ListConstruct %int4_4840, %298, %int14336_4841 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4076 = torch.aten.view %4074, %4075 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4076, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %4077 = torch.aten.mul.Tensor %4069, %4076 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4077, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_4842 = torch.constant.int -2 + %int-1_4843 = torch.constant.int -1 + %4078 = torch.aten.transpose.int %127, %int-2_4842, %int-1_4843 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_4844 = torch.constant.int 5 + %4079 = torch.prims.convert_element_type %4078, %int5_4844 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_4845 = torch.constant.int 14336 + %4080 = torch.prim.ListConstruct %342, %int14336_4845 : (!torch.int, !torch.int) -> !torch.list + %4081 = torch.aten.view %4077, %4080 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %4081, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %4082 = torch.aten.mm %4081, %4079 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4082, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_4846 = torch.constant.int 4 + %int4096_4847 = torch.constant.int 4096 + %4083 = torch.prim.ListConstruct %int4_4846, %298, %int4096_4847 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4084 = torch.aten.view %4082, %4083 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4084, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_4848 = torch.constant.int 1 + %4085 = torch.aten.add.Tensor %4051, %4084, 
%int1_4848 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4085, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_4849 = torch.constant.int 6 + %4086 = torch.prims.convert_element_type %4085, %int6_4849 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4086, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_4850 = torch.constant.int 2 + %4087 = torch.aten.pow.Tensor_Scalar %4086, %int2_4850 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4087, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_4851 = torch.constant.int -1 + %4088 = torch.prim.ListConstruct %int-1_4851 : (!torch.int) -> !torch.list + %true_4852 = torch.constant.bool true + %none_4853 = torch.constant.none + %4089 = torch.aten.mean.dim %4087, %4088, %true_4852, %none_4853 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4089, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_4854 = torch.constant.float 9.9999997473787516E-6 %int1_4855 = torch.constant.int 1 - %4093 = torch.aten.slice.Tensor %4092, %int1_4852, %int0_4853, %int9223372036854775807_4854, %int1_4855 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4093, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_4856 = torch.constant.int 1 - %int0_4857 = torch.constant.int 0 - %int9223372036854775807_4858 = torch.constant.int 9223372036854775807 - %int1_4859 = torch.constant.int 1 - %4094 = torch.aten.slice.Tensor %4093, %int1_4856, %int0_4857, %int9223372036854775807_4858, %int1_4859 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4094, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_4860 = torch.constant.int 0 - %4095 = torch.aten.unsqueeze %4094, %int0_4860 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4095, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_4861 = torch.constant.int 1 - %int0_4862 = torch.constant.int 0 - %int9223372036854775807_4863 = torch.constant.int 9223372036854775807 - %int1_4864 = torch.constant.int 1 - %4096 = torch.aten.slice.Tensor %4095, %int1_4861, %int0_4862, %int9223372036854775807_4863, %int1_4864 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4096, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_4865 = torch.constant.int 2 - %int0_4866 = torch.constant.int 0 - %int9223372036854775807_4867 = torch.constant.int 9223372036854775807 - %int1_4868 = torch.constant.int 1 - %4097 = torch.aten.slice.Tensor %4096, %int2_4865, %int0_4866, %int9223372036854775807_4867, %int1_4868 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4097, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_4869 = 
torch.constant.int 4 - %int1_4870 = torch.constant.int 1 - %int1_4871 = torch.constant.int 1 - %4098 = torch.prim.ListConstruct %int4_4869, %int1_4870, %int1_4871 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4099 = torch.aten.repeat %4097, %4098 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %4099, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_4872 = torch.constant.int 6 - %4100 = torch.prims.convert_element_type %4073, %int6_4872 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %4100, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %4101 = torch_c.to_builtin_tensor %4100 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %4102 = torch_c.to_builtin_tensor %4099 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %4103 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%4101, %4102) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %4104 = torch_c.from_builtin_tensor %4103 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %4104, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_4873 = torch.constant.int 5 - %4105 = torch.prims.convert_element_type %4104, %int5_4873 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4105, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_4874 = torch.constant.int 131072 - %none_4875 = torch.constant.none - %none_4876 = torch.constant.none - %cpu_4877 = torch.constant.device "cpu" - %false_4878 = torch.constant.bool false - %4106 = torch.aten.arange %int131072_4874, %none_4875, %none_4876, %cpu_4877, %false_4878 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_4879 = torch.constant.int 0 - %int128_4880 = torch.constant.int 128 - %none_4881 = torch.constant.none - %none_4882 = torch.constant.none - %cpu_4883 = torch.constant.device "cpu" - %false_4884 = torch.constant.bool false - %4107 = torch.aten.arange.start %int0_4879, %int128_4880, %none_4881, %none_4882, %cpu_4883, %false_4884 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_4885 = torch.constant.int 2 - %4108 = torch.aten.floor_divide.Scalar %4107, %int2_4885 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_4886 = torch.constant.int 6 - %4109 = torch.prims.convert_element_type %4108, %int6_4886 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_4887 = torch.constant.int 128 - %4110 = torch.aten.div.Scalar %4109, %int128_4887 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_4888 = torch.constant.float 2.000000e+00 - %4111 = torch.aten.mul.Scalar %4110, %float2.000000e00_4888 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_4889 = torch.constant.float 5.000000e+05 - %4112 = torch.aten.pow.Scalar %float5.000000e05_4889, %4111 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %4113 = torch.aten.reciprocal %4112 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_4890 = torch.constant.float 1.000000e+00 - %4114 = torch.aten.mul.Scalar %4113, 
%float1.000000e00_4890 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_4891 = torch.constant.int 1 - %4115 = torch.aten.unsqueeze %4106, %int1_4891 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_4892 = torch.constant.int 0 - %4116 = torch.aten.unsqueeze %4114, %int0_4892 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %4117 = torch.aten.mul.Tensor %4115, %4116 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_4893 = torch.constant.int 1 - %4118 = torch.aten.size.int %4064, %int1_4893 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_4894 = torch.constant.int 0 - %4119 = torch.aten.add.int %int0_4894, %4118 : !torch.int, !torch.int -> !torch.int - %int0_4895 = torch.constant.int 0 - %int0_4896 = torch.constant.int 0 - %int1_4897 = torch.constant.int 1 - %4120 = torch.aten.slice.Tensor %4117, %int0_4895, %int0_4896, %4119, %int1_4897 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4120, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_4898 = torch.constant.int 1 - %int0_4899 = torch.constant.int 0 - %int9223372036854775807_4900 = torch.constant.int 9223372036854775807 - %int1_4901 = torch.constant.int 1 - %4121 = torch.aten.slice.Tensor %4120, %int1_4898, %int0_4899, %int9223372036854775807_4900, %int1_4901 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4121, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_4902 = torch.constant.int 1 - %int0_4903 = torch.constant.int 0 - %int9223372036854775807_4904 = torch.constant.int 9223372036854775807 + %4090 = torch.aten.add.Scalar %4089, %float9.999990e-06_4854, %int1_4855 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4090, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4091 = torch.aten.rsqrt %4090 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4091, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4092 = torch.aten.mul.Tensor %4086, %4091 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4092, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_4856 = torch.constant.int 5 + %4093 = torch.prims.convert_element_type %4092, %int5_4856 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4093, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %4094 = torch.aten.mul.Tensor %128, %4093 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4094, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_4857 = torch.constant.int 5 + %4095 = torch.prims.convert_element_type %4094, %int5_4857 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4095, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_4858 = torch.constant.int -2 + %int-1_4859 = torch.constant.int -1 + %4096 = 
torch.aten.transpose.int %129, %int-2_4858, %int-1_4859 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_4860 = torch.constant.int 5 + %4097 = torch.prims.convert_element_type %4096, %int5_4860 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_4861 = torch.constant.int 4096 + %4098 = torch.prim.ListConstruct %342, %int4096_4861 : (!torch.int, !torch.int) -> !torch.list + %4099 = torch.aten.view %4095, %4098 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4099, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4100 = torch.aten.mm %4099, %4097 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4100, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_4862 = torch.constant.int 4 + %int4096_4863 = torch.constant.int 4096 + %4101 = torch.prim.ListConstruct %int4_4862, %298, %int4096_4863 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4102 = torch.aten.view %4100, %4101 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4102, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_4864 = torch.constant.int -2 + %int-1_4865 = torch.constant.int -1 + %4103 = torch.aten.transpose.int %130, %int-2_4864, %int-1_4865 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_4866 = torch.constant.int 5 + %4104 = torch.prims.convert_element_type %4103, %int5_4866 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_4867 = torch.constant.int 4096 + %4105 = torch.prim.ListConstruct %342, %int4096_4867 : (!torch.int, !torch.int) -> !torch.list + %4106 = torch.aten.view %4095, %4105 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4106, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4107 = torch.aten.mm %4106, %4104 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %4107, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_4868 = torch.constant.int 4 + %int1024_4869 = torch.constant.int 1024 + %4108 = torch.prim.ListConstruct %int4_4868, %298, %int1024_4869 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4109 = torch.aten.view %4107, %4108 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %4109, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_4870 = torch.constant.int -2 + %int-1_4871 = torch.constant.int -1 + %4110 = torch.aten.transpose.int %131, %int-2_4870, %int-1_4871 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_4872 = torch.constant.int 5 + %4111 = torch.prims.convert_element_type %4110, %int5_4872 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_4873 = torch.constant.int 4096 + %4112 = torch.prim.ListConstruct %342, %int4096_4873 : (!torch.int, !torch.int) -> !torch.list + %4113 = torch.aten.view %4095, %4112 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4113, [%294], 
affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4114 = torch.aten.mm %4113, %4111 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %4114, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_4874 = torch.constant.int 4 + %int1024_4875 = torch.constant.int 1024 + %4115 = torch.prim.ListConstruct %int4_4874, %298, %int1024_4875 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4116 = torch.aten.view %4114, %4115 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %4116, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_4876 = torch.constant.int 4 + %int32_4877 = torch.constant.int 32 + %int128_4878 = torch.constant.int 128 + %4117 = torch.prim.ListConstruct %int4_4876, %298, %int32_4877, %int128_4878 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4118 = torch.aten.view %4102, %4117 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4118, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_4879 = torch.constant.int 4 + %int8_4880 = torch.constant.int 8 + %int128_4881 = torch.constant.int 128 + %4119 = torch.prim.ListConstruct %int4_4879, %298, %int8_4880, %int128_4881 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4120 = torch.aten.view %4109, %4119 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4120, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_4882 = torch.constant.int 4 + %int8_4883 = torch.constant.int 8 + %int128_4884 = torch.constant.int 128 + %4121 = torch.prim.ListConstruct %int4_4882, %298, %int8_4883, %int128_4884 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4122 = torch.aten.view %4116, %4121 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4122, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_4885 = torch.constant.int 131072 + %none_4886 = torch.constant.none + %none_4887 = torch.constant.none + %cpu_4888 = torch.constant.device "cpu" + %false_4889 = torch.constant.bool false + %4123 = torch.aten.arange %int131072_4885, %none_4886, %none_4887, %cpu_4888, %false_4889 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_4890 = torch.constant.int 0 + %int128_4891 = torch.constant.int 128 + %int2_4892 = torch.constant.int 2 + %int4_4893 = torch.constant.int 4 + %none_4894 = torch.constant.none + %cpu_4895 = torch.constant.device "cpu" + %false_4896 = torch.constant.bool false + %4124 = torch.aten.arange.start_step %int0_4890, %int128_4891, %int2_4892, %int4_4893, %none_4894, %cpu_4895, %false_4896 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_4897 = torch.constant.int 6 + %4125 = torch.prims.convert_element_type %4124, %int6_4897 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_4898 = torch.constant.int 128 + %4126 = torch.aten.div.Scalar %4125, %int128_4898 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_4899 = torch.constant.float 5.000000e+05 + %4127 = 
torch.aten.pow.Scalar %float5.000000e05_4899, %4126 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4128 = torch.aten.reciprocal %4127 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_4900 = torch.constant.float 1.000000e+00 + %4129 = torch.aten.mul.Scalar %4128, %float1.000000e00_4900 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %4130 = torch.aten.reciprocal %4129 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_4901 = torch.constant.float 6.2831853071795862 + %4131 = torch.aten.mul.Scalar %4130, %float6.283190e00_4901 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_4902 = torch.constant.float 8.192000e+03 + %4132 = torch.aten.gt.Scalar %4131, %float8.192000e03_4902 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_4903 = torch.constant.int 8 + %4133 = torch.aten.div.Scalar %4129, %int8_4903 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %4134 = torch.aten.where.self %4132, %4133, %4129 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4135 = torch.aten.reciprocal %4131 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_4904 = torch.constant.int 8192 + %4136 = torch.aten.mul.Scalar %4135, %int8192_4904 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_4905 = torch.constant.int 1 - %4122 = torch.aten.slice.Tensor %4121, %int1_4902, %int0_4903, %int9223372036854775807_4904, %int1_4905 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4122, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_4906 = torch.constant.int 0 - %4123 = torch.aten.unsqueeze %4122, %int0_4906 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4123, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_4907 = torch.constant.int 1 - %int0_4908 = torch.constant.int 0 - %int9223372036854775807_4909 = torch.constant.int 9223372036854775807 - %int1_4910 = torch.constant.int 1 - %4124 = torch.aten.slice.Tensor %4123, %int1_4907, %int0_4908, %int9223372036854775807_4909, %int1_4910 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4124, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_4911 = torch.constant.int 2 - %int0_4912 = torch.constant.int 0 - %int9223372036854775807_4913 = torch.constant.int 9223372036854775807 - %int1_4914 = torch.constant.int 1 - %4125 = torch.aten.slice.Tensor %4124, %int2_4911, %int0_4912, %int9223372036854775807_4913, %int1_4914 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4125, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_4915 = torch.constant.int 4 + %int1_4906 = torch.constant.int 1 + %4137 = torch.aten.sub.Scalar %4136, %int1_4905, %int1_4906 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_4907 = torch.constant.int 3 + %4138 = torch.aten.div.Scalar %4137, %int3_4907 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_4908 = torch.constant.int 1 + %int1_4909 = torch.constant.int 1 + 
%4139 = torch.aten.rsub.Scalar %4138, %int1_4908, %int1_4909 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %4140 = torch.aten.mul.Tensor %4139, %4134 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_4910 = torch.constant.int 8 + %4141 = torch.aten.div.Scalar %4140, %int8_4910 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %4142 = torch.aten.mul.Tensor %4138, %4134 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_4911 = torch.constant.int 1 + %4143 = torch.aten.add.Tensor %4141, %4142, %int1_4911 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_4912 = torch.constant.float 2.048000e+03 + %4144 = torch.aten.lt.Scalar %4131, %float2.048000e03_4912 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %4145 = torch.aten.bitwise_not %4144 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_4913 = torch.constant.float 8.192000e+03 + %4146 = torch.aten.gt.Scalar %4131, %float8.192000e03_4913 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %4147 = torch.aten.bitwise_not %4146 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %4148 = torch.aten.mul.Tensor %4145, %4147 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %4149 = torch.aten.where.self %4148, %4143, %4134 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4150 = torch.prim.ListConstruct %4149, %4149 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_4914 = torch.constant.int -1 + %4151 = torch.aten.cat %4150, %int-1_4914 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_4915 = torch.constant.int 6 + %4152 = torch.prims.convert_element_type %4151, %int6_4915 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> %int1_4916 = torch.constant.int 1 - %int1_4917 = torch.constant.int 1 - %4126 = torch.prim.ListConstruct %int4_4915, %int1_4916, %int1_4917 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4127 = torch.aten.repeat %4125, %4126 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %4127, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_4918 = torch.constant.int 6 - %4128 = torch.prims.convert_element_type %4075, %int6_4918 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %4128, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %4129 = torch_c.to_builtin_tensor %4128 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %4130 = torch_c.to_builtin_tensor %4127 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %4131 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%4129, %4130) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %4132 = torch_c.from_builtin_tensor %4131 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %4132, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_4919 = torch.constant.int 5 - %4133 = torch.prims.convert_element_type %4132, %int5_4919 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4133, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 
128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_4920 = torch.constant.int 64 - %4134 = torch.aten.mul.Scalar %arg2, %int64_4920 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4134, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int38 = torch.constant.int 38 - %int1_4921 = torch.constant.int 1 - %4135 = torch.aten.add.Scalar %4134, %int38, %int1_4921 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4135, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_4922 = torch.constant.int 4 - %int32_4923 = torch.constant.int 32 - %int8_4924 = torch.constant.int 8 - %int128_4925 = torch.constant.int 128 - %4136 = torch.prim.ListConstruct %int4_4922, %398, %int32_4923, %int8_4924, %int128_4925 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4137 = torch.aten.view %4133, %4136 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4137, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_4926 = torch.constant.int 4 - %4138 = torch.aten.mul.int %int4_4926, %398 : !torch.int, !torch.int -> !torch.int - %int32_4927 = torch.constant.int 32 - %int8_4928 = torch.constant.int 8 - %int128_4929 = torch.constant.int 128 - %4139 = torch.prim.ListConstruct %4138, %int32_4927, %int8_4928, %int128_4929 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4140 = torch.aten.view %4137, %4139 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4140, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_4930 = torch.constant.int 4 - %4141 = torch.aten.mul.int %int4_4930, %398 : !torch.int, !torch.int -> !torch.int - %4142 = torch.prim.ListConstruct %4141 : (!torch.int) -> !torch.list - %4143 = torch.aten.view %4135, %4142 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %4143, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_4931 = torch.constant.int 32 - %int2_4932 = torch.constant.int 2 - %int32_4933 = torch.constant.int 32 - %int8_4934 = torch.constant.int 8 - %int128_4935 = torch.constant.int 128 - %4144 = torch.prim.ListConstruct %389, %int32_4931, %int2_4932, %int32_4933, %int8_4934, %int128_4935 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4145 = torch.aten.view %3977, %4144 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4145, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4936 = torch.constant.int 32 - %4146 = torch.aten.mul.int %389, %int32_4936 : !torch.int, !torch.int -> !torch.int - %int2_4937 = torch.constant.int 2 - %4147 = torch.aten.mul.int %4146, %int2_4937 : !torch.int, !torch.int -> !torch.int - %int32_4938 = torch.constant.int 32 - %int8_4939 = torch.constant.int 8 - %int128_4940 = torch.constant.int 128 - %4148 = torch.prim.ListConstruct %4147, %int32_4938, %int8_4939, %int128_4940 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4149 = torch.aten.view %4145, %4148 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4149, [%293], affine_map<()[s0] -> (s0 * 64, 
32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %4150 = torch.prim.ListConstruct %4143 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_4941 = torch.constant.bool false - %4151 = torch.aten.index_put %4149, %4150, %4140, %false_4941 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4151, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_4942 = torch.constant.int 32 - %int2_4943 = torch.constant.int 2 - %int32_4944 = torch.constant.int 32 - %int8_4945 = torch.constant.int 8 - %int128_4946 = torch.constant.int 128 - %4152 = torch.prim.ListConstruct %389, %int32_4942, %int2_4943, %int32_4944, %int8_4945, %int128_4946 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4153 = torch.aten.view %4151, %4152 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4153, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_4947 = torch.constant.int 2097152 - %4154 = torch.prim.ListConstruct %389, %int2097152_4947 : (!torch.int, !torch.int) -> !torch.list - %4155 = torch.aten.view %4153, %4154 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4155, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_4948 = torch.constant.int 32 - %int2_4949 = torch.constant.int 2 - %int32_4950 = torch.constant.int 32 - %int8_4951 = torch.constant.int 8 - %int128_4952 = torch.constant.int 128 - %4156 = torch.prim.ListConstruct %389, %int32_4948, %int2_4949, %int32_4950, %int8_4951, %int128_4952 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4157 = torch.aten.view %4155, %4156 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4157, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4953 = torch.constant.int 32 - %int8_4954 = torch.constant.int 8 - %int128_4955 = torch.constant.int 128 - %4158 = torch.prim.ListConstruct %4147, %int32_4953, %int8_4954, %int128_4955 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4159 = torch.aten.view %4157, %4158 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4159, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_4956 = torch.constant.int 4 - %int32_4957 = torch.constant.int 32 - %int8_4958 = torch.constant.int 8 - %int128_4959 = torch.constant.int 128 - %4160 = torch.prim.ListConstruct %int4_4956, %398, %int32_4957, %int8_4958, %int128_4959 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4161 = torch.aten.view %4077, %4160 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4161, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %4153 = torch.aten.unsqueeze %4123, %int1_4916 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_4917 = torch.constant.int 6 + %4154 = torch.prims.convert_element_type %4153, %int6_4917 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_4918 = 
torch.constant.int 0 + %4155 = torch.aten.unsqueeze %4152, %int0_4918 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_4919 = torch.constant.int 6 + %4156 = torch.prims.convert_element_type %4155, %int6_4919 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %4157 = torch.aten.mul.Tensor %4154, %4156 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %4158 = torch.aten.cos %4157 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_4920 = torch.constant.int 5 + %4159 = torch.prims.convert_element_type %4158, %int5_4920 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %4160 = torch.aten.sin %4157 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_4921 = torch.constant.int 5 + %4161 = torch.prims.convert_element_type %4160, %int5_4921 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_4922 = torch.constant.int 0 + %int0_4923 = torch.constant.int 0 + %int1_4924 = torch.constant.int 1 + %4162 = torch.aten.slice.Tensor %4159, %int0_4922, %int0_4923, %298, %int1_4924 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4162, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_4925 = torch.constant.int 1 + %int0_4926 = torch.constant.int 0 + %int9223372036854775807_4927 = torch.constant.int 9223372036854775807 + %int1_4928 = torch.constant.int 1 + %4163 = torch.aten.slice.Tensor %4162, %int1_4925, %int0_4926, %int9223372036854775807_4927, %int1_4928 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4163, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_4929 = torch.constant.int 0 + %int0_4930 = torch.constant.int 0 + %int1_4931 = torch.constant.int 1 + %4164 = torch.aten.slice.Tensor %4161, %int0_4929, %int0_4930, %298, %int1_4931 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4164, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_4932 = torch.constant.int 1 + %int0_4933 = torch.constant.int 0 + %int9223372036854775807_4934 = torch.constant.int 9223372036854775807 + %int1_4935 = torch.constant.int 1 + %4165 = torch.aten.slice.Tensor %4164, %int1_4932, %int0_4933, %int9223372036854775807_4934, %int1_4935 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4165, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_4936 = torch.constant.int 0 + %4166 = torch.aten.unsqueeze %4163, %int0_4936 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4166, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_4937 = torch.constant.int 1 + %int0_4938 = torch.constant.int 0 + %int9223372036854775807_4939 = torch.constant.int 9223372036854775807 + %int1_4940 = torch.constant.int 1 + %4167 = torch.aten.slice.Tensor %4166, %int1_4937, %int0_4938, %int9223372036854775807_4939, %int1_4940 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + 
torch.bind_symbolic_shape %4167, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_4941 = torch.constant.int 2 + %4168 = torch.aten.unsqueeze %4167, %int2_4941 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4168, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_4942 = torch.constant.int 3 + %int0_4943 = torch.constant.int 0 + %int9223372036854775807_4944 = torch.constant.int 9223372036854775807 + %int1_4945 = torch.constant.int 1 + %4169 = torch.aten.slice.Tensor %4168, %int3_4942, %int0_4943, %int9223372036854775807_4944, %int1_4945 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4169, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_4946 = torch.constant.int 4 + %int1_4947 = torch.constant.int 1 + %int1_4948 = torch.constant.int 1 + %int1_4949 = torch.constant.int 1 + %4170 = torch.prim.ListConstruct %int4_4946, %int1_4947, %int1_4948, %int1_4949 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4171 = torch.aten.repeat %4169, %4170 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %4171, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_4950 = torch.constant.int 0 + %4172 = torch.aten.unsqueeze %4165, %int0_4950 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4172, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_4951 = torch.constant.int 1 + %int0_4952 = torch.constant.int 0 + %int9223372036854775807_4953 = torch.constant.int 9223372036854775807 + %int1_4954 = torch.constant.int 1 + %4173 = torch.aten.slice.Tensor %4172, %int1_4951, %int0_4952, %int9223372036854775807_4953, %int1_4954 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4173, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_4955 = torch.constant.int 2 + %4174 = torch.aten.unsqueeze %4173, %int2_4955 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4174, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_4956 = torch.constant.int 3 + %int0_4957 = torch.constant.int 0 + %int9223372036854775807_4958 = torch.constant.int 9223372036854775807 + %int1_4959 = torch.constant.int 1 + %4175 = torch.aten.slice.Tensor %4174, %int3_4956, %int0_4957, %int9223372036854775807_4958, %int1_4959 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4175, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> %int4_4960 = torch.constant.int 4 - %4162 = torch.aten.mul.int %int4_4960, %398 : !torch.int, !torch.int -> !torch.int - %int32_4961 = torch.constant.int 32 - %int8_4962 = torch.constant.int 8 - %int128_4963 = torch.constant.int 128 - %4163 = torch.prim.ListConstruct %4162, %int32_4961, %int8_4962, %int128_4963 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4164 = torch.aten.view %4161, %4163 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> 
-    torch.bind_symbolic_shape %4164, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
-    %int1_4964 = torch.constant.int 1
-    %int1_4965 = torch.constant.int 1
-    %4165 = torch.aten.add.Scalar %4135, %int1_4964, %int1_4965 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
-    torch.bind_symbolic_shape %4165, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
-    %int4_4966 = torch.constant.int 4
-    %4166 = torch.aten.mul.int %int4_4966, %398 : !torch.int, !torch.int -> !torch.int
-    %4167 = torch.prim.ListConstruct %4166 : (!torch.int) -> !torch.list<int>
-    %4168 = torch.aten.view %4165, %4167 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
-    torch.bind_symbolic_shape %4168, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
-    %4169 = torch.prim.ListConstruct %4168 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
-    %false_4967 = torch.constant.bool false
-    %4170 = torch.aten.index_put %4159, %4169, %4164, %false_4967 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16>
-    torch.bind_symbolic_shape %4170, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
-    %int32_4968 = torch.constant.int 32
-    %int2_4969 = torch.constant.int 2
-    %int32_4970 = torch.constant.int 32
-    %int8_4971 = torch.constant.int 8
-    %int128_4972 = torch.constant.int 128
-    %4171 = torch.prim.ListConstruct %389, %int32_4968, %int2_4969, %int32_4970, %int8_4971, %int128_4972 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4172 = torch.aten.view %4170, %4171 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
-    torch.bind_symbolic_shape %4172, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
-    %int2097152_4973 = torch.constant.int 2097152
-    %4173 = torch.prim.ListConstruct %389, %int2097152_4973 : (!torch.int, !torch.int) -> !torch.list<int>
-    %4174 = torch.aten.view %4172, %4173 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
-    torch.bind_symbolic_shape %4174, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
-    %int-2_4974 = torch.constant.int -2
-    %4175 = torch.aten.unsqueeze %4133, %int-2_4974 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
-    torch.bind_symbolic_shape %4175, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
-    %int4_4975 = torch.constant.int 4
-    %int8_4976 = torch.constant.int 8
-    %int4_4977 = torch.constant.int 4
-    %int128_4978 = torch.constant.int 128
-    %4176 = torch.prim.ListConstruct %int4_4975, %4118, %int8_4976, %int4_4977, %int128_4978 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %false_4979 = torch.constant.bool false
-    %4177 = torch.aten.expand %4175, %4176, %false_4979 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %4177, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int0_4980 = torch.constant.int 0
-    %4178 = torch.aten.clone %4177, %int0_4980 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %4178, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int4_4981 = torch.constant.int 4
-    %int32_4982 = torch.constant.int 32
-    %int128_4983 = torch.constant.int 128
-    %4179 = torch.prim.ListConstruct %int4_4981, %4118, %int32_4982, %int128_4983 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4180 = torch.aten._unsafe_view %4178, %4179 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
-    torch.bind_symbolic_shape %4180, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
-    %int-2_4984 = torch.constant.int -2
-    %4181 = torch.aten.unsqueeze %4077, %int-2_4984 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
-    torch.bind_symbolic_shape %4181, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
-    %int1_4985 = torch.constant.int 1
-    %4182 = torch.aten.size.int %4071, %int1_4985 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int
-    %int4_4986 = torch.constant.int 4
-    %int8_4987 = torch.constant.int 8
-    %int4_4988 = torch.constant.int 4
-    %int128_4989 = torch.constant.int 128
-    %4183 = torch.prim.ListConstruct %int4_4986, %4182, %int8_4987, %int4_4988, %int128_4989 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %false_4990 = torch.constant.bool false
-    %4184 = torch.aten.expand %4181, %4183, %false_4990 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %4184, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int0_4991 = torch.constant.int 0
-    %4185 = torch.aten.clone %4184, %int0_4991 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %4185, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int4_4992 = torch.constant.int 4
-    %int32_4993 = torch.constant.int 32
-    %int128_4994 = torch.constant.int 128
-    %4186 = torch.prim.ListConstruct %int4_4992, %4182, %int32_4993, %int128_4994 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4187 = torch.aten._unsafe_view %4185, %4186 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
-    torch.bind_symbolic_shape %4187, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+    %int1_4961 = torch.constant.int 1
+    %int1_4962 = torch.constant.int 1
+    %int1_4963 = torch.constant.int 1
+    %4176 = torch.prim.ListConstruct %int4_4960, %int1_4961, %int1_4962, %int1_4963 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4177 = torch.aten.repeat %4175, %4176 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+    torch.bind_symbolic_shape %4177, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+    %4178 = torch.aten.mul.Tensor %4118, %4171 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16>
+    torch.bind_symbolic_shape %4178, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+    %int3_4964 = torch.constant.int 3
+    %int0_4965 = torch.constant.int 0
+    %int64_4966 = torch.constant.int 64
+    %int1_4967 = torch.constant.int 1
+    %4179 = torch.aten.slice.Tensor %4118, %int3_4964, %int0_4965, %int64_4966, %int1_4967 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16>
+    torch.bind_symbolic_shape %4179, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16>
+    %int3_4968 = torch.constant.int 3
+    %int64_4969 = torch.constant.int 64
+    %int9223372036854775807_4970 = torch.constant.int 9223372036854775807
+    %int1_4971 = torch.constant.int 1
+    %4180 = torch.aten.slice.Tensor %4118, %int3_4968, %int64_4969, %int9223372036854775807_4970, %int1_4971 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16>
+    torch.bind_symbolic_shape %4180, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16>
+    %4181 = torch.aten.neg %4180 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16>
+    torch.bind_symbolic_shape %4181, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16>
+    %4182 = torch.prim.ListConstruct %4181, %4179 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list<vtensor>
+    %int-1_4972 = torch.constant.int -1
+    %4183 = torch.aten.cat %4182, %int-1_4972 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
+    torch.bind_symbolic_shape %4183, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+    %4184 = torch.aten.mul.Tensor %4183, %4177 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16>
+    torch.bind_symbolic_shape %4184, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+    %int1_4973 = torch.constant.int 1
+    %4185 = torch.aten.add.Tensor %4178, %4184, %int1_4973 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
+    torch.bind_symbolic_shape %4185, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+    %int131072_4974 = torch.constant.int 131072
+    %none_4975 = torch.constant.none
+    %none_4976 = torch.constant.none
+    %cpu_4977 = torch.constant.device "cpu"
+    %false_4978 = torch.constant.bool false
+    %4186 = torch.aten.arange %int131072_4974, %none_4975, %none_4976, %cpu_4977, %false_4978 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
+    %int0_4979 = torch.constant.int 0
+    %int128_4980 = torch.constant.int 128
+    %int2_4981 = torch.constant.int 2
+    %int4_4982 = torch.constant.int 4
+    %none_4983 = torch.constant.none
+    %cpu_4984 = torch.constant.device "cpu"
+    %false_4985 = torch.constant.bool false
+    %4187 = torch.aten.arange.start_step %int0_4979, %int128_4980, %int2_4981, %int4_4982, %none_4983, %cpu_4984, %false_4985 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64>
+    %int6_4986 = torch.constant.int 6
+    %4188 = torch.prims.convert_element_type %4187, %int6_4986 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32>
+    %int128_4987 = torch.constant.int 128
+    %4189 = torch.aten.div.Scalar %4188, %int128_4987 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+    %float5.000000e05_4988 = torch.constant.float 5.000000e+05
+    %4190 = torch.aten.pow.Scalar %float5.000000e05_4988, %4189 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+    %4191 = torch.aten.reciprocal %4190 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+    %float1.000000e00_4989 = torch.constant.float 1.000000e+00
+    %4192 = torch.aten.mul.Scalar %4191, %float1.000000e00_4989 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+    %4193 = torch.aten.reciprocal %4192 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+    %float6.283190e00_4990 = torch.constant.float 6.2831853071795862
+    %4194 = torch.aten.mul.Scalar %4193, %float6.283190e00_4990 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+    %float8.192000e03_4991 = torch.constant.float 8.192000e+03
+    %4195 = torch.aten.gt.Scalar %4194, %float8.192000e03_4991 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+    %int8_4992 = torch.constant.int 8
+    %4196 = torch.aten.div.Scalar %4192, %int8_4992 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+    %4197 = torch.aten.where.self %4195, %4196, %4192 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+    %4198 = torch.aten.reciprocal %4194 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+    %int8192_4993 = torch.constant.int 8192
+    %4199 = torch.aten.mul.Scalar %4198, %int8192_4993 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+    %int1_4994 = torch.constant.int 1
     %int1_4995 = torch.constant.int 1
-    %int2_4996 = torch.constant.int 2
-    %4188 = torch.aten.transpose.int %4105, %int1_4995, %int2_4996 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
-    torch.bind_symbolic_shape %4188, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+    %4200 = torch.aten.sub.Scalar %4199, %int1_4994, %int1_4995 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+    %int3_4996 = torch.constant.int 3
+    %4201 = torch.aten.div.Scalar %4200, %int3_4996 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
     %int1_4997 = torch.constant.int 1
-    %int2_4998 = torch.constant.int 2
-    %4189 = torch.aten.transpose.int %4180, %int1_4997, %int2_4998 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
-    torch.bind_symbolic_shape %4189, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
-    %int1_4999 = torch.constant.int 1
-    %int2_5000 = torch.constant.int 2
-    %4190 = torch.aten.transpose.int %4187, %int1_4999, %int2_5000 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
-    torch.bind_symbolic_shape %4190, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
-    %float0.000000e00_5001 = torch.constant.float 0.000000e+00
-    %true_5002 = torch.constant.bool true
-    %none_5003 = torch.constant.none
-    %none_5004 = torch.constant.none
-    %4191:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4188, %4189, %4190, %float0.000000e00_5001, %true_5002, %none_5003, %none_5004) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>)
-    torch.bind_symbolic_shape %4191#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+    %int1_4998 = torch.constant.int 1
+    %4202 = torch.aten.rsub.Scalar %4201, %int1_4997, %int1_4998 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+    %4203 = torch.aten.mul.Tensor %4202, %4197 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+    %int8_4999 = torch.constant.int 8
+    %4204 = torch.aten.div.Scalar %4203, %int8_4999 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+    %4205 = torch.aten.mul.Tensor %4201, %4197 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+    %int1_5000 = torch.constant.int 1
+    %4206 = torch.aten.add.Tensor %4204, %4205, %int1_5000 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+    %float2.048000e03_5001 = torch.constant.float 2.048000e+03
+    %4207 = torch.aten.lt.Scalar %4194, %float2.048000e03_5001 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+    %4208 = torch.aten.bitwise_not %4207 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+    %float8.192000e03_5002 = torch.constant.float 8.192000e+03
+    %4209 = torch.aten.gt.Scalar %4194, %float8.192000e03_5002 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+    %4210 = torch.aten.bitwise_not %4209 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+    %4211 = torch.aten.mul.Tensor %4208, %4210 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+    %4212 = torch.aten.where.self %4211, %4206, %4197 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+    %4213 = torch.prim.ListConstruct %4212, %4212 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list<vtensor>
+    %int-1_5003 = torch.constant.int -1
+    %4214 = torch.aten.cat %4213, %int-1_5003 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[128],f32>
+    %int6_5004 = torch.constant.int 6
+    %4215 = torch.prims.convert_element_type %4214, %int6_5004 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
     %int1_5005 = torch.constant.int 1
-    %int2_5006 = torch.constant.int 2
-    %4192 = torch.aten.transpose.int %4191#0, %int1_5005, %int2_5006 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
-    torch.bind_symbolic_shape %4192, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
-    %int4_5007 = torch.constant.int 4
-    %int4096_5008 = torch.constant.int 4096
-    %4193 = torch.prim.ListConstruct %int4_5007, %4090, %int4096_5008 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4194 = torch.aten.view %4192, %4193 : !torch.vtensor<[4,?,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
-    torch.bind_symbolic_shape %4194, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
-    %int-2_5009 = torch.constant.int -2
-    %int-1_5010 = torch.constant.int -1
-    %4195 = torch.aten.transpose.int %176, %int-2_5009, %int-1_5010 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
-    %int4_5011 = torch.constant.int 4
-    %4196 = torch.aten.mul.int %int4_5011, %4090 : !torch.int, !torch.int -> !torch.int
-    %int4096_5012 = torch.constant.int 4096
-    %4197 = torch.prim.ListConstruct %4196, %int4096_5012 : (!torch.int, !torch.int) -> !torch.list<int>
-    %4198 = torch.aten.view %4194, %4197 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
-    torch.bind_symbolic_shape %4198, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
-    %4199 = torch.aten.mm %4198, %4195 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
-    torch.bind_symbolic_shape %4199, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
-    %int4_5013 = torch.constant.int 4
-    %int4096_5014 = torch.constant.int 4096
-    %4200 = torch.prim.ListConstruct %int4_5013, %4090, %int4096_5014 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4201 = torch.aten.view %4199, %4200 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
-    torch.bind_symbolic_shape %4201, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
-    %int1_5015 = torch.constant.int 1
-    %4202 = torch.aten.add.Tensor %4040, %4201, %int1_5015 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
-    torch.bind_symbolic_shape %4202, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
-    %int6_5016 = torch.constant.int 6
-    %4203 = torch.prims.convert_element_type %4202, %int6_5016 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
-    torch.bind_symbolic_shape %4203, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
-    %int2_5017 = torch.constant.int 2
-    %4204 = torch.aten.pow.Tensor_Scalar %4203, %int2_5017 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
-    torch.bind_symbolic_shape %4204, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
-    %int-1_5018 = torch.constant.int -1
-    %4205 = torch.prim.ListConstruct %int-1_5018 : (!torch.int) -> !torch.list<int>
-    %true_5019 = torch.constant.bool true
-    %none_5020 = torch.constant.none
-    %4206 = torch.aten.mean.dim %4204, %4205, %true_5019, %none_5020 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
-    torch.bind_symbolic_shape %4206, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
-    %float9.999990e-06_5021 = torch.constant.float 9.9999997473787516E-6
-    %int1_5022 = torch.constant.int 1
-    %4207 = torch.aten.add.Scalar %4206, %float9.999990e-06_5021, %int1_5022 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
-    torch.bind_symbolic_shape %4207, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
-    %4208 = torch.aten.rsqrt %4207 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
-    torch.bind_symbolic_shape %4208, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
-    %4209 = torch.aten.mul.Tensor %4203, %4208 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
-    torch.bind_symbolic_shape %4209, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
-    %int5_5023 = torch.constant.int 5
-    %4210 = torch.prims.convert_element_type %4209, %int5_5023 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
-    torch.bind_symbolic_shape %4210, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
-    %4211 = torch.aten.mul.Tensor %177, %4210 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
-    torch.bind_symbolic_shape %4211, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
-    %int5_5024 = torch.constant.int 5
-    %4212 = torch.prims.convert_element_type %4211, %int5_5024 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
-    torch.bind_symbolic_shape %4212, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
-    %int-2_5025 = torch.constant.int -2
-    %int-1_5026 = torch.constant.int -1
-    %4213 = torch.aten.transpose.int %178, %int-2_5025, %int-1_5026 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
-    %int4_5027 = torch.constant.int 4
-    %4214 = torch.aten.mul.int %int4_5027, %306 : !torch.int, !torch.int -> !torch.int
-    %int4096_5028 = torch.constant.int 4096
-    %4215 = torch.prim.ListConstruct %4214, %int4096_5028 : (!torch.int, !torch.int) -> !torch.list<int>
-    %4216 = torch.aten.view %4212, %4215 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
-    torch.bind_symbolic_shape %4216, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
-    %4217 = torch.aten.mm %4216, %4213 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
-    torch.bind_symbolic_shape %4217, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
-    %int4_5029 = torch.constant.int 4
-    %int14336_5030 = torch.constant.int 14336
-    %4218 = torch.prim.ListConstruct %int4_5029, %306, %int14336_5030 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4219 = torch.aten.view %4217, %4218 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
-    torch.bind_symbolic_shape %4219, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
-    %4220 = torch.aten.silu %4219 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
-    torch.bind_symbolic_shape %4220, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
-    %int-2_5031 = torch.constant.int -2
-    %int-1_5032 = torch.constant.int -1
-    %4221 = torch.aten.transpose.int %179, %int-2_5031, %int-1_5032 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
-    %int4_5033 = torch.constant.int 4
-    %4222 = torch.aten.mul.int %int4_5033, %306 : !torch.int, !torch.int -> !torch.int
-    %int4096_5034 = torch.constant.int 4096
-    %4223 = torch.prim.ListConstruct %4222, %int4096_5034 : (!torch.int, !torch.int) -> !torch.list<int>
-    %4224 = torch.aten.view %4212, %4223 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
-    torch.bind_symbolic_shape %4224, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
-    %4225 = torch.aten.mm %4224, %4221 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
-    torch.bind_symbolic_shape %4225, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
+    %4216 = torch.aten.unsqueeze %4186, %int1_5005 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
+    %int6_5006 = torch.constant.int 6
+    %4217 = torch.prims.convert_element_type %4216, %int6_5006 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32>
+    %int0_5007 = torch.constant.int 0
+    %4218 = torch.aten.unsqueeze %4215, %int0_5007 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+    %int6_5008 = torch.constant.int 6
+    %4219 = torch.prims.convert_element_type %4218, %int6_5008 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+    %4220 = torch.aten.mul.Tensor %4217, %4219 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
+    %4221 = torch.aten.cos %4220 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+    %int5_5009 = torch.constant.int 5
+    %4222 = torch.prims.convert_element_type %4221, %int5_5009 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
+    %4223 = torch.aten.sin %4220 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+    %int5_5010 = torch.constant.int 5
+    %4224 = torch.prims.convert_element_type %4223, %int5_5010 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
+    %int0_5011 = torch.constant.int 0
+    %int0_5012 = torch.constant.int 0
+    %int1_5013 = torch.constant.int 1
+    %4225 = torch.aten.slice.Tensor %4222, %int0_5011, %int0_5012, %298, %int1_5013 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+    torch.bind_symbolic_shape %4225, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+    %int1_5014 = torch.constant.int 1
+    %int0_5015 = torch.constant.int 0
+    %int9223372036854775807_5016 = torch.constant.int 9223372036854775807
+    %int1_5017 = torch.constant.int 1
+    %4226 = torch.aten.slice.Tensor %4225, %int1_5014, %int0_5015, %int9223372036854775807_5016, %int1_5017 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+    torch.bind_symbolic_shape %4226, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+    %int0_5018 = torch.constant.int 0
+    %int0_5019 = torch.constant.int 0
+    %int1_5020 = torch.constant.int 1
+    %4227 = torch.aten.slice.Tensor %4224, %int0_5018, %int0_5019, %298, %int1_5020 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+    torch.bind_symbolic_shape %4227, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+    %int1_5021 = torch.constant.int 1
+    %int0_5022 = torch.constant.int 0
+    %int9223372036854775807_5023 = torch.constant.int 9223372036854775807
+    %int1_5024 = torch.constant.int 1
+    %4228 = torch.aten.slice.Tensor %4227, %int1_5021, %int0_5022, %int9223372036854775807_5023, %int1_5024 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+    torch.bind_symbolic_shape %4228, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+    %int0_5025 = torch.constant.int 0
+    %4229 = torch.aten.unsqueeze %4226, %int0_5025 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+    torch.bind_symbolic_shape %4229, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+    %int1_5026 = torch.constant.int 1
+    %int0_5027 = torch.constant.int 0
+    %int9223372036854775807_5028 = torch.constant.int 9223372036854775807
+    %int1_5029 = torch.constant.int 1
+    %4230 = torch.aten.slice.Tensor %4229, %int1_5026, %int0_5027, %int9223372036854775807_5028, %int1_5029 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+    torch.bind_symbolic_shape %4230, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+    %int2_5030 = torch.constant.int 2
+    %4231 = torch.aten.unsqueeze %4230, %int2_5030 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+    torch.bind_symbolic_shape %4231, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+    %int3_5031 = torch.constant.int 3
+    %int0_5032 = torch.constant.int 0
+    %int9223372036854775807_5033 = torch.constant.int 9223372036854775807
+    %int1_5034 = torch.constant.int 1
+    %4232 = torch.aten.slice.Tensor %4231, %int3_5031, %int0_5032, %int9223372036854775807_5033, %int1_5034 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+    torch.bind_symbolic_shape %4232, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
     %int4_5035 = torch.constant.int 4
-    %int14336_5036 = torch.constant.int 14336
-    %4226 = torch.prim.ListConstruct %int4_5035, %306, %int14336_5036 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4227 = torch.aten.view %4225, %4226 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
-    torch.bind_symbolic_shape %4227, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
-    %4228 = torch.aten.mul.Tensor %4220, %4227 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
-    torch.bind_symbolic_shape %4228, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
-    %int-2_5037 = torch.constant.int -2
-    %int-1_5038 = torch.constant.int -1
-    %4229 = torch.aten.transpose.int %180, %int-2_5037, %int-1_5038 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
-    %int1_5039 = torch.constant.int 1
-    %4230 = torch.aten.size.int %4219, %int1_5039 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int
-    %int4_5040 = torch.constant.int 4
-    %4231 = torch.aten.mul.int %int4_5040, %4230 : !torch.int, !torch.int -> !torch.int
-    %int14336_5041 = torch.constant.int 14336
-    %4232 = torch.prim.ListConstruct %4231, %int14336_5041 : (!torch.int, !torch.int) -> !torch.list<int>
-    %4233 = torch.aten.view %4228, %4232 : !torch.vtensor<[4,?,14336],f16>, !torch.list<int> -> !torch.vtensor<[?,14336],f16>
-    torch.bind_symbolic_shape %4233, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
-    %4234 = torch.aten.mm %4233, %4229 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16>
-    torch.bind_symbolic_shape %4234, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
-    %int4_5042 = torch.constant.int 4
-    %int4096_5043 = torch.constant.int 4096
-    %4235 = torch.prim.ListConstruct %int4_5042, %4230, %int4096_5043 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4236 = torch.aten.view %4234, %4235 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
-    torch.bind_symbolic_shape %4236, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
-    %int1_5044 = torch.constant.int 1
-    %4237 = torch.aten.add.Tensor %4202, %4236, %int1_5044 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
-    torch.bind_symbolic_shape %4237, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
-    %int6_5045 = torch.constant.int 6
-    %4238 = torch.prims.convert_element_type %4237, %int6_5045 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
-    torch.bind_symbolic_shape %4238, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
-    %int2_5046 = torch.constant.int 2
-    %4239 = torch.aten.pow.Tensor_Scalar %4238, %int2_5046 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
-    torch.bind_symbolic_shape %4239, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
-    %int-1_5047 = torch.constant.int -1
-    %4240 = torch.prim.ListConstruct %int-1_5047 : (!torch.int) -> !torch.list<int>
-    %true_5048 = torch.constant.bool true
-    %none_5049 = torch.constant.none
-    %4241 = torch.aten.mean.dim %4239, %4240, %true_5048, %none_5049 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
-    torch.bind_symbolic_shape %4241, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
-    %float9.999990e-06_5050 = torch.constant.float 9.9999997473787516E-6
+    %int1_5036 = torch.constant.int 1
+    %int1_5037 = torch.constant.int 1
+    %int1_5038 = torch.constant.int 1
+    %4233 = torch.prim.ListConstruct %int4_5035, %int1_5036, %int1_5037, %int1_5038 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4234 = torch.aten.repeat %4232, %4233 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+    torch.bind_symbolic_shape %4234, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+    %int0_5039 = torch.constant.int 0
+    %4235 = torch.aten.unsqueeze %4228, %int0_5039 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+    torch.bind_symbolic_shape %4235, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+    %int1_5040 = torch.constant.int 1
+    %int0_5041 = torch.constant.int 0
+    %int9223372036854775807_5042 = torch.constant.int 9223372036854775807
+    %int1_5043 = torch.constant.int 1
+    %4236 = torch.aten.slice.Tensor %4235, %int1_5040, %int0_5041, %int9223372036854775807_5042, %int1_5043 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+    torch.bind_symbolic_shape %4236, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+    %int2_5044 = torch.constant.int 2
+    %4237 = torch.aten.unsqueeze %4236, %int2_5044 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+    torch.bind_symbolic_shape %4237, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+    %int3_5045 = torch.constant.int 3
+    %int0_5046 = torch.constant.int 0
+    %int9223372036854775807_5047 = torch.constant.int 9223372036854775807
+    %int1_5048 = torch.constant.int 1
+    %4238 = torch.aten.slice.Tensor %4237, %int3_5045, %int0_5046, %int9223372036854775807_5047, %int1_5048 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+    torch.bind_symbolic_shape %4238, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+    %int4_5049 = torch.constant.int 4
+    %int1_5050 = torch.constant.int 1
     %int1_5051 = torch.constant.int 1
-    %4242 = torch.aten.add.Scalar %4241, %float9.999990e-06_5050, %int1_5051 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
-    torch.bind_symbolic_shape %4242, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
-    %4243 = torch.aten.rsqrt %4242 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
-    torch.bind_symbolic_shape %4243, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
-    %4244 = torch.aten.mul.Tensor %4238, %4243 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
-    torch.bind_symbolic_shape %4244, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
-    %int5_5052 = torch.constant.int 5
-    %4245 = torch.prims.convert_element_type %4244, %int5_5052 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
-    torch.bind_symbolic_shape %4245, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
-    %4246 = torch.aten.mul.Tensor %181, %4245 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
-    torch.bind_symbolic_shape %4246, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
-    %int5_5053 = torch.constant.int 5
-    %4247 = torch.prims.convert_element_type %4246, %int5_5053 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
-    torch.bind_symbolic_shape %4247, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
-    %int-2_5054 = torch.constant.int -2
-    %int-1_5055 = torch.constant.int -1
-    %4248 = torch.aten.transpose.int %182, %int-2_5054, %int-1_5055 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
-    %int4_5056 = torch.constant.int 4
-    %4249 = torch.aten.mul.int %int4_5056, %306 : !torch.int, !torch.int -> !torch.int
-    %int4096_5057 = torch.constant.int 4096
-    %4250 = torch.prim.ListConstruct %4249, %int4096_5057 : (!torch.int, !torch.int) -> !torch.list<int>
-    %4251 = torch.aten.view %4247, %4250 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
-    torch.bind_symbolic_shape %4251, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
-    %4252 = torch.aten.mm %4251, %4248 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
-    torch.bind_symbolic_shape %4252, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
-    %int4_5058 = torch.constant.int 4
-    %int4096_5059 = torch.constant.int 4096
-    %4253 = torch.prim.ListConstruct %int4_5058, %306, %int4096_5059 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4254 = torch.aten.view %4252, %4253 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
-    torch.bind_symbolic_shape %4254, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
-    %int-2_5060 = torch.constant.int -2
+    %int1_5052 = torch.constant.int 1
+    %4239 = torch.prim.ListConstruct %int4_5049, %int1_5050, %int1_5051, %int1_5052 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4240 = torch.aten.repeat %4238, %4239 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+    torch.bind_symbolic_shape %4240, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+    %4241 = torch.aten.mul.Tensor %4120, %4234 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16>
+    torch.bind_symbolic_shape %4241, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+    %int3_5053 = torch.constant.int 3
+    %int0_5054 = torch.constant.int 0
+    %int64_5055 = torch.constant.int 64
+    %int1_5056 = torch.constant.int 1
+    %4242 = torch.aten.slice.Tensor %4120, %int3_5053, %int0_5054, %int64_5055, %int1_5056 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16>
+    torch.bind_symbolic_shape %4242, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16>
+    %int3_5057 = torch.constant.int 3
+    %int64_5058 = torch.constant.int 64
+    %int9223372036854775807_5059 = torch.constant.int 9223372036854775807
+    %int1_5060 = torch.constant.int 1
+    %4243 = torch.aten.slice.Tensor %4120, %int3_5057, %int64_5058, %int9223372036854775807_5059, %int1_5060 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16>
+    torch.bind_symbolic_shape %4243, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16>
+    %4244 = torch.aten.neg %4243 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16>
+    torch.bind_symbolic_shape %4244, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16>
+    %4245 = torch.prim.ListConstruct %4244, %4242 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list<vtensor>
     %int-1_5061 = torch.constant.int -1
-    %4255 = torch.aten.transpose.int %183, %int-2_5060, %int-1_5061 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
-    %int4_5062 = torch.constant.int 4
-    %4256 = torch.aten.mul.int %int4_5062, %306 : !torch.int, !torch.int -> !torch.int
-    %int4096_5063 = torch.constant.int 4096
-    %4257 = torch.prim.ListConstruct %4256, %int4096_5063 : (!torch.int, !torch.int) -> !torch.list<int>
-    %4258 = torch.aten.view %4247, %4257 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
-    torch.bind_symbolic_shape %4258, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
-    %4259 = torch.aten.mm %4258, %4255 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
-    torch.bind_symbolic_shape %4259, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
-    %int4_5064 = torch.constant.int 4
-    %int1024_5065 = torch.constant.int 1024
-    %4260 = torch.prim.ListConstruct %int4_5064, %306, %int1024_5065 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4261 = torch.aten.view %4259, %4260 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
-    torch.bind_symbolic_shape %4261, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
-    %int-2_5066 = torch.constant.int -2
-    %int-1_5067 = torch.constant.int -1
-    %4262 = torch.aten.transpose.int %184, %int-2_5066, %int-1_5067 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+    %4246 = torch.aten.cat %4245, %int-1_5061 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
+    torch.bind_symbolic_shape %4246, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+    %4247 = torch.aten.mul.Tensor %4246, %4240 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16>
+    torch.bind_symbolic_shape %4247, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+    %int1_5062 = torch.constant.int 1
+    %4248 = torch.aten.add.Tensor %4241, %4247, %int1_5062 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
+    torch.bind_symbolic_shape %4248, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+    %int32_5063 = torch.constant.int 32
+    %4249 = torch.aten.mul.Scalar %arg2, %int32_5063 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+    torch.bind_symbolic_shape %4249, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+    %int14 = torch.constant.int 14
+    %int1_5064 = torch.constant.int 1
+    %4250 = torch.aten.add.Scalar %4249, %int14, %int1_5064 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+    torch.bind_symbolic_shape %4250, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+    %int2_5065 = torch.constant.int 2
+    %4251 = torch.aten.mul.Scalar %4250, %int2_5065 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+    torch.bind_symbolic_shape %4251, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+    %int0_5066 = torch.constant.int 0
+    %int1_5067 = torch.constant.int 1
+    %4252 = torch.aten.add.Scalar %4251, %int0_5066, %int1_5067 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+    torch.bind_symbolic_shape %4252, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+    %4253 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list<int>
+    %4254 = torch.aten.view %4252, %4253 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
+    torch.bind_symbolic_shape %4254, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
     %int4_5068 = torch.constant.int 4
-    %4263 = torch.aten.mul.int %int4_5068, %306 : !torch.int, !torch.int -> !torch.int
-    %int4096_5069 = torch.constant.int 4096
-    %4264 = torch.prim.ListConstruct %4263, %int4096_5069 : (!torch.int, !torch.int) -> !torch.list<int>
-    %4265 = torch.aten.view %4247, %4264 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
-    torch.bind_symbolic_shape %4265, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
-    %4266 = torch.aten.mm %4265, %4262 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
-    torch.bind_symbolic_shape %4266, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
-    %int4_5070 = torch.constant.int 4
-    %int1024_5071 = torch.constant.int 1024
-    %4267 = torch.prim.ListConstruct %int4_5070, %306, %int1024_5071 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4268 = torch.aten.view %4266, %4267 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
-    torch.bind_symbolic_shape %4268, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
-    %int4_5072 = torch.constant.int 4
-    %int32_5073 = torch.constant.int 32
+    %int32_5069 = torch.constant.int 32
+    %int8_5070 = torch.constant.int 8
+    %int128_5071 = torch.constant.int 128
+    %4255 = torch.prim.ListConstruct %int4_5068, %296, %int32_5069, %int8_5070, %int128_5071 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4256 = torch.aten.view %4248, %4255 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %4256, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+    %int32_5072 = torch.constant.int 32
+    %int8_5073 = torch.constant.int 8
     %int128_5074 = torch.constant.int 128
-    %4269 = torch.prim.ListConstruct %int4_5072, %306, %int32_5073, %int128_5074 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4270 = torch.aten.view %4254, %4269 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
-    torch.bind_symbolic_shape %4270, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
-    %int4_5075 = torch.constant.int 4
-    %int8_5076 = torch.constant.int 8
-    %int128_5077 = torch.constant.int 128
-    %4271 = torch.prim.ListConstruct %int4_5075, %306, %int8_5076, %int128_5077 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4272 = torch.aten.view %4261, %4271 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
-    torch.bind_symbolic_shape %4272, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
-    %int4_5078 = torch.constant.int 4
-    %int8_5079 = torch.constant.int 8
-    %int128_5080 = torch.constant.int 128
-    %4273 = torch.prim.ListConstruct %int4_5078, %306, %int8_5079, %int128_5080 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4274 = torch.aten.view %4268, %4273 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
-    torch.bind_symbolic_shape %4274, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
-    %int131072_5081 = torch.constant.int 131072
-    %none_5082 = torch.constant.none
-    %none_5083 = torch.constant.none
-    %cpu_5084 = torch.constant.device "cpu"
-    %false_5085 = torch.constant.bool false
-    %4275 = torch.aten.arange %int131072_5081, %none_5082, %none_5083, %cpu_5084, %false_5085 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
-    %int0_5086 = torch.constant.int 0
-    %int128_5087 = torch.constant.int 128
-    %none_5088 = torch.constant.none
-    %none_5089 = torch.constant.none
-    %cpu_5090 = torch.constant.device "cpu"
-    %false_5091 = torch.constant.bool false
-    %4276 = torch.aten.arange.start %int0_5086, %int128_5087, %none_5088, %none_5089, %cpu_5090, %false_5091 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64>
-    %int2_5092 = torch.constant.int 2
-    %4277 = torch.aten.floor_divide.Scalar %4276, %int2_5092 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64>
-    %int6_5093 = torch.constant.int 6
-    %4278 = torch.prims.convert_element_type %4277, %int6_5093 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32>
-    %int128_5094 = torch.constant.int 128
-    %4279 = torch.aten.div.Scalar %4278, %int128_5094 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
-    %float2.000000e00_5095 = torch.constant.float 2.000000e+00
-    %4280 = torch.aten.mul.Scalar %4279, %float2.000000e00_5095 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
-    %float5.000000e05_5096 = torch.constant.float 5.000000e+05
-    %4281 = torch.aten.pow.Scalar %float5.000000e05_5096, %4280 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
-    %4282 = torch.aten.reciprocal %4281 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
-    %float1.000000e00_5097 = torch.constant.float 1.000000e+00
-    %4283 = torch.aten.mul.Scalar %4282, %float1.000000e00_5097 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
-    %int1_5098 = torch.constant.int 1
-    %4284 = torch.aten.unsqueeze %4275, %int1_5098 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
-    %int0_5099 = torch.constant.int 0
-    %4285 = torch.aten.unsqueeze %4283, %int0_5099 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
-    %4286 = torch.aten.mul.Tensor %4284, %4285 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
-    %int1_5100 = torch.constant.int 1
-    %4287 = torch.aten.size.int %4254, %int1_5100 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int
-    %int0_5101 = torch.constant.int 0
-    %4288 = torch.aten.add.int %int0_5101, %4287 : !torch.int, !torch.int -> !torch.int
-    %int0_5102 = torch.constant.int 0
-    %int0_5103 = torch.constant.int 0
-    %int1_5104 = torch.constant.int 1
-    %4289 = torch.aten.slice.Tensor %4286, %int0_5102, %int0_5103, %4288, %int1_5104 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
-    torch.bind_symbolic_shape %4289, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
+    %4257 = torch.prim.ListConstruct %504, %int32_5072, %int8_5073, %int128_5074 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4258 = torch.aten.view %4256, %4257 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
+    torch.bind_symbolic_shape %4258, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
+    %int1_5075 = torch.constant.int 1
+    %int2_5076 = torch.constant.int 2
+    %4259 = torch.aten.transpose.int %4258, %int1_5075, %int2_5076 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+    torch.bind_symbolic_shape %4259, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+    %int5_5077 = torch.constant.int 5
+    %4260 = torch.prims.convert_element_type %4259, %int5_5077 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+    torch.bind_symbolic_shape %4260, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+    %int32_5078 = torch.constant.int 32
+    %int2_5079 = torch.constant.int 2
+    %int8_5080 = torch.constant.int 8
+    %int32_5081 = torch.constant.int 32
+    %int128_5082 = torch.constant.int 128
+    %4261 = torch.prim.ListConstruct %297, %int32_5078, %int2_5079, %int8_5080, %int32_5081, %int128_5082 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4262 = torch.aten.view %4024, %4261 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4262, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %int8_5083 = torch.constant.int 8
+    %int32_5084 = torch.constant.int 32
+    %int128_5085 = torch.constant.int 128
+    %4263 = torch.prim.ListConstruct %497, %int8_5083, %int32_5084, %int128_5085 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4264 = torch.aten.view %4262, %4263 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,32,128],f16>
+    torch.bind_symbolic_shape %4264, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+    %4265 = torch.prim.ListConstruct %4254 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
+    %false_5086 = torch.constant.bool false
+    %4266 = torch.aten.index_put %4264, %4265, %4260, %false_5086 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16>
+    torch.bind_symbolic_shape %4266, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+    %int32_5087 = torch.constant.int 32
+    %int2_5088 = torch.constant.int 2
+    %int8_5089 = torch.constant.int 8
+    %int32_5090 = torch.constant.int 32
+    %int128_5091 = torch.constant.int 128
+    %4267 = torch.prim.ListConstruct %297, %int32_5087, %int2_5088, %int8_5089, %int32_5090, %int128_5091 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4268 = torch.aten.view %4266, %4267 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4268, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %int2097152_5092 = torch.constant.int 2097152
+    %4269 = torch.prim.ListConstruct %297, %int2097152_5092 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4270 = torch.aten.view %4268, %4269 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+    torch.bind_symbolic_shape %4270, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+    %int32_5093 = torch.constant.int 32
+    %int2_5094 = torch.constant.int 2
+    %int8_5095 = torch.constant.int 8
+    %int32_5096 = torch.constant.int 32
+    %int128_5097 = torch.constant.int 128
+    %4271 = torch.prim.ListConstruct %297, %int32_5093, %int2_5094, %int8_5095, %int32_5096, %int128_5097 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4272 = torch.aten.view %4270, %4271 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4272, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %int8_5098 = torch.constant.int 8
+    %int32_5099 = torch.constant.int 32
+    %int128_5100 = torch.constant.int 128
+    %4273 = torch.prim.ListConstruct %497, %int8_5098, %int32_5099, %int128_5100 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4274 = torch.aten.view %4272, %4273 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,32,128],f16>
+    torch.bind_symbolic_shape %4274, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+    %int32_5101 = torch.constant.int 32
+    %4275 = torch.aten.mul.Scalar %arg2, %int32_5101 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+    torch.bind_symbolic_shape %4275, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+    %int14_5102 = torch.constant.int 14
+    %int1_5103 = torch.constant.int 1
+    %4276 = torch.aten.add.Scalar %4275, %int14_5102, %int1_5103 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+    torch.bind_symbolic_shape %4276, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+    %int2_5104 = torch.constant.int 2
+    %4277 = torch.aten.mul.Scalar %4276, %int2_5104 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+    torch.bind_symbolic_shape %4277, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
     %int1_5105 = torch.constant.int 1
-    %int0_5106 = torch.constant.int 0
-    %int9223372036854775807_5107 = torch.constant.int 9223372036854775807
-    %int1_5108 = torch.constant.int 1
-    %4290 = torch.aten.slice.Tensor %4289, %int1_5105, %int0_5106, %int9223372036854775807_5107, %int1_5108 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
-    torch.bind_symbolic_shape %4290, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
-    %int1_5109 = torch.constant.int 1
-    %int0_5110 = torch.constant.int 0
-    %int9223372036854775807_5111 = torch.constant.int 9223372036854775807
-    %int1_5112 = torch.constant.int 1
-    %4291 = torch.aten.slice.Tensor %4290, %int1_5109, %int0_5110, %int9223372036854775807_5111, %int1_5112 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
-    torch.bind_symbolic_shape %4291, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
-    %int0_5113 = torch.constant.int 0
-    %4292 = torch.aten.unsqueeze %4291, %int0_5113 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32>
-    torch.bind_symbolic_shape %4292, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
+    %int1_5106 = torch.constant.int 1
+    %4278 = torch.aten.add.Scalar %4277, %int1_5105, %int1_5106 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+    torch.bind_symbolic_shape %4278, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+    %4279 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list<int>
+    %4280 = torch.aten.view %4278, %4279 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
+    torch.bind_symbolic_shape %4280, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
+    %int4_5107 = torch.constant.int 4
+    %int32_5108 = torch.constant.int 32
+    %int8_5109 = torch.constant.int 8
+    %int128_5110 = torch.constant.int 128
+    %4281 = torch.prim.ListConstruct %int4_5107, %296, %int32_5108, %int8_5109, %int128_5110 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4282 = torch.aten.view %4122, %4281 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %4282, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+    %int32_5111 = torch.constant.int 32
+    %int8_5112 = torch.constant.int 8
+    %int128_5113 = torch.constant.int 128
+    %4283 = torch.prim.ListConstruct %504, %int32_5111, %int8_5112, %int128_5113 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4284 = torch.aten.view %4282, %4283 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
+    torch.bind_symbolic_shape %4284, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
     %int1_5114 = torch.constant.int 1
-    %int0_5115 = torch.constant.int 0
-    %int9223372036854775807_5116 = torch.constant.int 9223372036854775807
-    %int1_5117 = torch.constant.int 1
-    %4293 = torch.aten.slice.Tensor %4292, %int1_5114, %int0_5115, %int9223372036854775807_5116, %int1_5117 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32>
-    torch.bind_symbolic_shape %4293, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
-    %int2_5118 = torch.constant.int 2
-    %int0_5119 = torch.constant.int 0
-    %int9223372036854775807_5120 = torch.constant.int 9223372036854775807
-    %int1_5121 = torch.constant.int 1
-    %4294 = torch.aten.slice.Tensor %4293, %int2_5118, %int0_5119, %int9223372036854775807_5120, %int1_5121 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32>
-    torch.bind_symbolic_shape %4294, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
-    %int4_5122 = torch.constant.int 4
-    %int1_5123 = torch.constant.int 1
-    %int1_5124 = torch.constant.int 1
-    %4295 = torch.prim.ListConstruct %int4_5122, %int1_5123, %int1_5124 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %4296 = torch.aten.repeat %4294, %4295 : !torch.vtensor<[1,?,128],f32>, !torch.list<int> -> !torch.vtensor<[4,?,128],f32>
-    torch.bind_symbolic_shape %4296, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32>
-    %int6_5125 = torch.constant.int 6
-    %4297 = torch.prims.convert_element_type %4270, %int6_5125 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32>
-    torch.bind_symbolic_shape %4297, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32>
-    %4298 = torch_c.to_builtin_tensor %4297 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32>
-    %4299 = torch_c.to_builtin_tensor %4296 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32>
-    %4300 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%4298, %4299) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32>
-    %4301 = torch_c.from_builtin_tensor %4300 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32>
-    torch.bind_symbolic_shape %4301, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32>
-    %int5_5126 = torch.constant.int 5
-    %4302 = torch.prims.convert_element_type %4301, %int5_5126 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
-    torch.bind_symbolic_shape %4302, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
-    %int131072_5127 = torch.constant.int 131072
-    %none_5128 = torch.constant.none
-    %none_5129 = torch.constant.none
-    %cpu_5130 = torch.constant.device "cpu"
-    %false_5131 = torch.constant.bool false
-    %4303 = torch.aten.arange %int131072_5127, %none_5128, %none_5129, %cpu_5130, %false_5131 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
-    %int0_5132 = torch.constant.int 0
+    %int2_5115 = torch.constant.int 2
+    %4285 = torch.aten.transpose.int %4284, %int1_5114, %int2_5115 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+    torch.bind_symbolic_shape %4285, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+    %int5_5116 = torch.constant.int 5
+    %4286 = torch.prims.convert_element_type %4285, %int5_5116 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+    torch.bind_symbolic_shape %4286, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+    %4287 = torch.prim.ListConstruct %4280 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
+    %false_5117 = torch.constant.bool false
+    %4288 = torch.aten.index_put %4274, %4287, %4286, %false_5117 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16>
+    torch.bind_symbolic_shape %4288, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+    %int32_5118 = torch.constant.int 32
+    %int2_5119 = torch.constant.int 2
+    %int8_5120 = torch.constant.int 8
+    %int32_5121 = torch.constant.int 32
+    %int128_5122 = torch.constant.int 128
+    %4289 = torch.prim.ListConstruct %297, %int32_5118, %int2_5119, %int8_5120, %int32_5121, %int128_5122 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4290 = torch.aten.view %4288, %4289 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4290, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %int2097152_5123 = torch.constant.int 2097152
+    %4291 = torch.prim.ListConstruct %297, %int2097152_5123 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4292 = torch.aten.view %4290, %4291 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+    torch.bind_symbolic_shape %4292, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+    %int-2_5124 = torch.constant.int -2
+    %4293 = torch.aten.unsqueeze %4248, %int-2_5124 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+    torch.bind_symbolic_shape %4293, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+    %int4_5125 = torch.constant.int 4
+    %int8_5126 = torch.constant.int 8
+    %int4_5127 = torch.constant.int 4
+    %int128_5128 = torch.constant.int 128
+    %4294 = torch.prim.ListConstruct %int4_5125, %298, %int8_5126, %int4_5127, %int128_5128 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %false_5129 = torch.constant.bool false
+    %4295 = torch.aten.expand %4293, %4294, %false_5129 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %4295, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int0_5130 = torch.constant.int 0
+    %4296 = torch.aten.clone %4295, %int0_5130 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %4296, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int4_5131 = torch.constant.int 4
+    %int32_5132 = torch.constant.int 32
     %int128_5133 = torch.constant.int 128
-    %none_5134 = torch.constant.none
-    %none_5135 = torch.constant.none
-    %cpu_5136 = torch.constant.device "cpu"
-    %false_5137 = torch.constant.bool false
-    %4304 = torch.aten.arange.start %int0_5132, %int128_5133, %none_5134, %none_5135, %cpu_5136, %false_5137 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64>
-    %int2_5138 = torch.constant.int 2
-    %4305 = torch.aten.floor_divide.Scalar %4304, %int2_5138 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64>
-    %int6_5139 = torch.constant.int 6
-    %4306 = torch.prims.convert_element_type %4305, %int6_5139 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32>
-    %int128_5140 = torch.constant.int 128
-    %4307 = torch.aten.div.Scalar %4306, %int128_5140 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
-    %float2.000000e00_5141 = torch.constant.float 2.000000e+00
-    %4308 = torch.aten.mul.Scalar %4307, %float2.000000e00_5141 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
-    %float5.000000e05_5142 = torch.constant.float 5.000000e+05
-    %4309 = torch.aten.pow.Scalar %float5.000000e05_5142, %4308 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
-    %4310 = torch.aten.reciprocal %4309 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
-    %float1.000000e00_5143 = torch.constant.float 1.000000e+00
-    %4311 = torch.aten.mul.Scalar %4310, %float1.000000e00_5143 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
+    %4297 = torch.prim.ListConstruct %int4_5131, %298, %int32_5132, %int128_5133 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4298 = torch.aten._unsafe_view %4296, %4297 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+    torch.bind_symbolic_shape %4298, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+    %int-2_5134 = torch.constant.int -2
+    %4299 = torch.aten.unsqueeze %4122, %int-2_5134 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+    torch.bind_symbolic_shape %4299, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+    %int4_5135 = torch.constant.int 4
+    %int8_5136 = torch.constant.int 8
+    %int4_5137 = torch.constant.int 4
+    %int128_5138 = torch.constant.int 128
+    %4300 = torch.prim.ListConstruct %int4_5135, %298, %int8_5136, %int4_5137, %int128_5138 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %false_5139 = torch.constant.bool false
+    %4301 = torch.aten.expand %4299, %4300, %false_5139 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %4301, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int0_5140 = torch.constant.int 0
+    %4302 = torch.aten.clone %4301, %int0_5140 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %4302, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int4_5141 = torch.constant.int 4
+    %int32_5142 = torch.constant.int 32
+    %int128_5143 = torch.constant.int 128
+    %4303 = torch.prim.ListConstruct %int4_5141, %298, %int32_5142, %int128_5143 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4304 = torch.aten._unsafe_view %4302, %4303 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+    torch.bind_symbolic_shape %4304, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
     %int1_5144 = torch.constant.int 1
-    %4312 = torch.aten.unsqueeze %4303, %int1_5144 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
-    %int0_5145 = torch.constant.int 0
-    %4313 = torch.aten.unsqueeze %4311, %int0_5145 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
-    %4314 = torch.aten.mul.Tensor %4312, %4313 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
+    %int2_5145 = torch.constant.int 2
+    %4305 = torch.aten.transpose.int %4185, %int1_5144, %int2_5145 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+    torch.bind_symbolic_shape %4305, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
     %int1_5146 = torch.constant.int 1
-    %4315 = torch.aten.size.int %4261, %int1_5146 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int
-    %int0_5147 = torch.constant.int 0
-    %4316 = torch.aten.add.int %int0_5147, %4315 : !torch.int, !torch.int -> !torch.int
-    %int0_5148 = torch.constant.int 0
-    %int0_5149 = torch.constant.int 0
-    %int1_5150 = torch.constant.int 1
-    %4317 = torch.aten.slice.Tensor %4314, %int0_5148, %int0_5149, %4316, %int1_5150 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
-    torch.bind_symbolic_shape %4317, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
-    %int1_5151 = torch.constant.int 1
-    %int0_5152 = torch.constant.int 0
-    %int9223372036854775807_5153 = torch.constant.int 9223372036854775807
-    %int1_5154 = torch.constant.int 1
-    %4318 = torch.aten.slice.Tensor %4317, %int1_5151, %int0_5152, %int9223372036854775807_5153, %int1_5154 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
-    torch.bind_symbolic_shape %4318, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
-    %int1_5155 = torch.constant.int 1
-    %int0_5156 = torch.constant.int 0
-    %int9223372036854775807_5157 = torch.constant.int 9223372036854775807
-    %int1_5158 = torch.constant.int 1
-    %4319 = torch.aten.slice.Tensor %4318, %int1_5155, %int0_5156, %int9223372036854775807_5157, %int1_5158 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
-    torch.bind_symbolic_shape %4319, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
-    %int0_5159 = torch.constant.int 0
-    %4320 =
torch.aten.unsqueeze %4319, %int0_5159 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4320, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_5160 = torch.constant.int 1 - %int0_5161 = torch.constant.int 0 - %int9223372036854775807_5162 = torch.constant.int 9223372036854775807 + %int2_5147 = torch.constant.int 2 + %4306 = torch.aten.transpose.int %4298, %int1_5146, %int2_5147 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4306, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_5148 = torch.constant.int 1 + %int2_5149 = torch.constant.int 2 + %4307 = torch.aten.transpose.int %4304, %int1_5148, %int2_5149 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4307, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_5150 = torch.constant.float 0.000000e+00 + %false_5151 = torch.constant.bool false + %none_5152 = torch.constant.none + %4308:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4305, %4306, %4307, %float0.000000e00_5150, %false_5151, %327, %none_5152) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %4308#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_5153 = torch.constant.int 1 + %int2_5154 = torch.constant.int 2 + %4309 = torch.aten.transpose.int %4308#0, %int1_5153, %int2_5154 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4309, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_5155 = torch.constant.int 4 + %int4096_5156 = torch.constant.int 4096 + %4310 = torch.prim.ListConstruct %int4_5155, %298, %int4096_5156 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4311 = torch.aten.view %4309, %4310 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4311, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_5157 = torch.constant.int -2 + %int-1_5158 = torch.constant.int -1 + %4312 = torch.aten.transpose.int %132, %int-2_5157, %int-1_5158 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5159 = torch.constant.int 5 + %4313 = torch.prims.convert_element_type %4312, %int5_5159 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_5160 = torch.constant.int 4096 + %4314 = torch.prim.ListConstruct %342, %int4096_5160 : (!torch.int, !torch.int) -> !torch.list + %4315 = torch.aten.view %4311, %4314 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4315, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4316 = torch.aten.mm %4315, %4313 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4316, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_5161 = 
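The attention call itself changes in this hunk: the removed %4388 invoked _scaled_dot_product_flash_attention_for_cpu with is_causal = true and no mask, while the added %4308 passes is_causal = false plus the precomputed additive mask %327 of shape [4,1,?,?], broadcast over the 32 heads. A rough eager-mode equivalent, assuming the usual 0 / -inf causal mask contents:

```python
import torch
import torch.nn.functional as F

b, h, s, d = 4, 32, 32, 128
q, k, v = (torch.randn(b, h, s, d, dtype=torch.float16) for _ in range(3))
# explicit additive causal mask, [b, 1, s, s], standing in for %327
mask = torch.full((b, 1, s, s), float("-inf"), dtype=torch.float16).triu(1)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask,
                                     dropout_p=0.0, is_causal=False)
```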
torch.constant.int 4 + %int4096_5162 = torch.constant.int 4096 + %4317 = torch.prim.ListConstruct %int4_5161, %298, %int4096_5162 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4318 = torch.aten.view %4316, %4317 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4318, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> %int1_5163 = torch.constant.int 1 - %4321 = torch.aten.slice.Tensor %4320, %int1_5160, %int0_5161, %int9223372036854775807_5162, %int1_5163 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4321, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_5164 = torch.constant.int 2 - %int0_5165 = torch.constant.int 0 - %int9223372036854775807_5166 = torch.constant.int 9223372036854775807 - %int1_5167 = torch.constant.int 1 - %4322 = torch.aten.slice.Tensor %4321, %int2_5164, %int0_5165, %int9223372036854775807_5166, %int1_5167 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4322, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_5168 = torch.constant.int 4 - %int1_5169 = torch.constant.int 1 + %4319 = torch.aten.add.Tensor %4085, %4318, %int1_5163 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4319, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_5164 = torch.constant.int 6 + %4320 = torch.prims.convert_element_type %4319, %int6_5164 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4320, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_5165 = torch.constant.int 2 + %4321 = torch.aten.pow.Tensor_Scalar %4320, %int2_5165 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4321, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_5166 = torch.constant.int -1 + %4322 = torch.prim.ListConstruct %int-1_5166 : (!torch.int) -> !torch.list + %true_5167 = torch.constant.bool true + %none_5168 = torch.constant.none + %4323 = torch.aten.mean.dim %4321, %4322, %true_5167, %none_5168 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4323, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_5169 = torch.constant.float 9.9999997473787516E-6 %int1_5170 = torch.constant.int 1 - %4323 = torch.prim.ListConstruct %int4_5168, %int1_5169, %int1_5170 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4324 = torch.aten.repeat %4322, %4323 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %4324, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_5171 = torch.constant.int 6 - %4325 = torch.prims.convert_element_type %4272, %int6_5171 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %4325, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %4326 = torch_c.to_builtin_tensor %4325 : !torch.vtensor<[4,?,8,128],f32> -> 
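%4309-%4319 form the attention epilogue: transpose back to [4,?,32,128], merge the heads into 4096, run the output projection as a flattened mm against the transposed [4096,4096] weight %132, and add the residual %4085. A sketch with hypothetical names:

```python
import torch

def attn_epilogue(attn, w_o, residual):
    b, h, s, d = attn.shape                        # [4, 32, seq, 128]
    x = attn.transpose(1, 2).reshape(b, s, h * d)  # [4, seq, 4096]
    x = x.reshape(b * s, h * d) @ w_o.t()          # w_o: [4096, 4096]
    return residual + x.reshape(b, s, h * d)
```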
tensor<4x?x8x128xf32> - %4327 = torch_c.to_builtin_tensor %4324 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %4328 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%4326, %4327) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %4329 = torch_c.from_builtin_tensor %4328 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %4329, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> + %4324 = torch.aten.add.Scalar %4323, %float9.999990e-06_5169, %int1_5170 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4324, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4325 = torch.aten.rsqrt %4324 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4325, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4326 = torch.aten.mul.Tensor %4320, %4325 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4326, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_5171 = torch.constant.int 5 + %4327 = torch.prims.convert_element_type %4326, %int5_5171 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4327, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %4328 = torch.aten.mul.Tensor %133, %4327 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4328, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> %int5_5172 = torch.constant.int 5 - %4330 = torch.prims.convert_element_type %4329, %int5_5172 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4330, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_5173 = torch.constant.int 64 - %4331 = torch.aten.mul.Scalar %arg2, %int64_5173 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4331, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int40 = torch.constant.int 40 - %int1_5174 = torch.constant.int 1 - %4332 = torch.aten.add.Scalar %4331, %int40, %int1_5174 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4332, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_5175 = torch.constant.int 4 - %int32_5176 = torch.constant.int 32 - %int8_5177 = torch.constant.int 8 - %int128_5178 = torch.constant.int 128 - %4333 = torch.prim.ListConstruct %int4_5175, %398, %int32_5176, %int8_5177, %int128_5178 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4334 = torch.aten.view %4330, %4333 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4334, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_5179 = torch.constant.int 4 - %4335 = torch.aten.mul.int %int4_5179, %398 : !torch.int, !torch.int -> !torch.int - %int32_5180 = torch.constant.int 32 - %int8_5181 = torch.constant.int 8 - %int128_5182 = torch.constant.int 128 - %4336 = torch.prim.ListConstruct %4335, %int32_5180, %int8_5181, %int128_5182 : (!torch.int, !torch.int, 
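%4320-%4329 are an RMSNorm: square and mean in fp32, add the eps constant (9.9999997473787516e-6, i.e. 1e-5 rounded to fp32), rsqrt, scale, cast toward fp16, then multiply by the fp32 norm weight %133. Compactly:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5):
    h = x.float()                                  # fp16 -> fp32 for stability
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return (weight * h.to(torch.float16)).to(torch.float16)
```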
!torch.int, !torch.int) -> !torch.list - %4337 = torch.aten.view %4334, %4336 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4337, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %4329 = torch.prims.convert_element_type %4328, %int5_5172 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4329, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_5173 = torch.constant.int -2 + %int-1_5174 = torch.constant.int -1 + %4330 = torch.aten.transpose.int %134, %int-2_5173, %int-1_5174 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5175 = torch.constant.int 5 + %4331 = torch.prims.convert_element_type %4330, %int5_5175 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_5176 = torch.constant.int 4096 + %4332 = torch.prim.ListConstruct %342, %int4096_5176 : (!torch.int, !torch.int) -> !torch.list + %4333 = torch.aten.view %4329, %4332 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4333, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4334 = torch.aten.mm %4333, %4331 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %4334, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_5177 = torch.constant.int 4 + %int14336_5178 = torch.constant.int 14336 + %4335 = torch.prim.ListConstruct %int4_5177, %298, %int14336_5178 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4336 = torch.aten.view %4334, %4335 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4336, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %4337 = torch.aten.silu %4336 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4337, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_5179 = torch.constant.int -2 + %int-1_5180 = torch.constant.int -1 + %4338 = torch.aten.transpose.int %135, %int-2_5179, %int-1_5180 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5181 = torch.constant.int 5 + %4339 = torch.prims.convert_element_type %4338, %int5_5181 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_5182 = torch.constant.int 4096 + %4340 = torch.prim.ListConstruct %342, %int4096_5182 : (!torch.int, !torch.int) -> !torch.list + %4341 = torch.aten.view %4329, %4340 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4341, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4342 = torch.aten.mm %4341, %4339 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %4342, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> %int4_5183 = torch.constant.int 4 - %4338 = torch.aten.mul.int %int4_5183, %398 : !torch.int, !torch.int -> !torch.int - %4339 = torch.prim.ListConstruct %4338 : (!torch.int) -> !torch.list - %4340 = torch.aten.view %4332, %4339 : !torch.vtensor<[4,?],si64>, !torch.list -> 
!torch.vtensor<[?],si64> - torch.bind_symbolic_shape %4340, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_5184 = torch.constant.int 32 - %int2_5185 = torch.constant.int 2 - %int32_5186 = torch.constant.int 32 - %int8_5187 = torch.constant.int 8 - %int128_5188 = torch.constant.int 128 - %4341 = torch.prim.ListConstruct %389, %int32_5184, %int2_5185, %int32_5186, %int8_5187, %int128_5188 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4342 = torch.aten.view %4174, %4341 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4342, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5189 = torch.constant.int 32 - %4343 = torch.aten.mul.int %389, %int32_5189 : !torch.int, !torch.int -> !torch.int - %int2_5190 = torch.constant.int 2 - %4344 = torch.aten.mul.int %4343, %int2_5190 : !torch.int, !torch.int -> !torch.int - %int32_5191 = torch.constant.int 32 - %int8_5192 = torch.constant.int 8 - %int128_5193 = torch.constant.int 128 - %4345 = torch.prim.ListConstruct %4344, %int32_5191, %int8_5192, %int128_5193 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4346 = torch.aten.view %4342, %4345 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4346, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %4347 = torch.prim.ListConstruct %4340 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_5194 = torch.constant.bool false - %4348 = torch.aten.index_put %4346, %4347, %4337, %false_5194 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4348, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_5195 = torch.constant.int 32 - %int2_5196 = torch.constant.int 2 - %int32_5197 = torch.constant.int 32 - %int8_5198 = torch.constant.int 8 - %int128_5199 = torch.constant.int 128 - %4349 = torch.prim.ListConstruct %389, %int32_5195, %int2_5196, %int32_5197, %int8_5198, %int128_5199 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4350 = torch.aten.view %4348, %4349 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4350, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5200 = torch.constant.int 2097152 - %4351 = torch.prim.ListConstruct %389, %int2097152_5200 : (!torch.int, !torch.int) -> !torch.list - %4352 = torch.aten.view %4350, %4351 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4352, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_5201 = torch.constant.int 32 - %int2_5202 = torch.constant.int 2 - %int32_5203 = torch.constant.int 32 - %int8_5204 = torch.constant.int 8 - %int128_5205 = torch.constant.int 128 - %4353 = torch.prim.ListConstruct %389, %int32_5201, %int2_5202, %int32_5203, %int8_5204, %int128_5205 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4354 = torch.aten.view %4352, %4353 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4354, [%293], 
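Much of this hunk looks like a KV-cache layout change rather than new math: the removed lines view each cache page as [32, 2, 32, 8, 128] (tokens before KV heads) and perform two separate index_put writes, while the added lines use [32, 2, 8, 32, 128] (heads before tokens), presumably so each head's page is a contiguous row for the attention reads. Same 2097152 elements per page, different stride order:

```python
import torch

page = torch.empty(2097152, dtype=torch.float16)  # one cache page
old_view = page.view(32, 2, 32, 8, 128)  # removed IR: token-major pages
new_view = page.view(32, 2, 8, 32, 128)  # added IR: head-major pages
```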
affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5206 = torch.constant.int 32 - %int8_5207 = torch.constant.int 8 - %int128_5208 = torch.constant.int 128 - %4355 = torch.prim.ListConstruct %4344, %int32_5206, %int8_5207, %int128_5208 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4356 = torch.aten.view %4354, %4355 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4356, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_5209 = torch.constant.int 4 - %int32_5210 = torch.constant.int 32 - %int8_5211 = torch.constant.int 8 - %int128_5212 = torch.constant.int 128 - %4357 = torch.prim.ListConstruct %int4_5209, %398, %int32_5210, %int8_5211, %int128_5212 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4358 = torch.aten.view %4274, %4357 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4358, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_5213 = torch.constant.int 4 - %4359 = torch.aten.mul.int %int4_5213, %398 : !torch.int, !torch.int -> !torch.int - %int32_5214 = torch.constant.int 32 - %int8_5215 = torch.constant.int 8 - %int128_5216 = torch.constant.int 128 - %4360 = torch.prim.ListConstruct %4359, %int32_5214, %int8_5215, %int128_5216 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4361 = torch.aten.view %4358, %4360 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4361, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_5217 = torch.constant.int 1 - %int1_5218 = torch.constant.int 1 - %4362 = torch.aten.add.Scalar %4332, %int1_5217, %int1_5218 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4362, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int14336_5184 = torch.constant.int 14336 + %4343 = torch.prim.ListConstruct %int4_5183, %298, %int14336_5184 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4344 = torch.aten.view %4342, %4343 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4344, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %4345 = torch.aten.mul.Tensor %4337, %4344 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4345, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_5185 = torch.constant.int -2 + %int-1_5186 = torch.constant.int -1 + %4346 = torch.aten.transpose.int %136, %int-2_5185, %int-1_5186 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_5187 = torch.constant.int 5 + %4347 = torch.prims.convert_element_type %4346, %int5_5187 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_5188 = torch.constant.int 14336 + %4348 = torch.prim.ListConstruct %342, %int14336_5188 : (!torch.int, !torch.int) -> !torch.list + %4349 = torch.aten.view %4345, %4348 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %4349, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : 
!torch.vtensor<[?,14336],f16> + %4350 = torch.aten.mm %4349, %4347 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4350, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_5189 = torch.constant.int 4 + %int4096_5190 = torch.constant.int 4096 + %4351 = torch.prim.ListConstruct %int4_5189, %298, %int4096_5190 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4352 = torch.aten.view %4350, %4351 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4352, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_5191 = torch.constant.int 1 + %4353 = torch.aten.add.Tensor %4319, %4352, %int1_5191 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4353, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_5192 = torch.constant.int 6 + %4354 = torch.prims.convert_element_type %4353, %int6_5192 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4354, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_5193 = torch.constant.int 2 + %4355 = torch.aten.pow.Tensor_Scalar %4354, %int2_5193 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4355, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_5194 = torch.constant.int -1 + %4356 = torch.prim.ListConstruct %int-1_5194 : (!torch.int) -> !torch.list + %true_5195 = torch.constant.bool true + %none_5196 = torch.constant.none + %4357 = torch.aten.mean.dim %4355, %4356, %true_5195, %none_5196 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4357, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_5197 = torch.constant.float 9.9999997473787516E-6 + %int1_5198 = torch.constant.int 1 + %4358 = torch.aten.add.Scalar %4357, %float9.999990e-06_5197, %int1_5198 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4358, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4359 = torch.aten.rsqrt %4358 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4359, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4360 = torch.aten.mul.Tensor %4354, %4359 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4360, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_5199 = torch.constant.int 5 + %4361 = torch.prims.convert_element_type %4360, %int5_5199 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4361, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %4362 = torch.aten.mul.Tensor %137, %4361 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4362, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_5200 = torch.constant.int 5 + %4363 = 
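%4330-%4353 are the SwiGLU feed-forward: gate and up projections against the [14336,4096] weights %134/%135, silu on the gate, elementwise product, down projection through the [4096,14336] weight %136, then the residual add onto %4319; every matmul runs over the flattened [s0*128, ...] token dimension. Sketch:

```python
import torch
import torch.nn.functional as F

def swiglu_ffn(x, w_gate, w_up, w_down, residual):
    b, s, d = x.shape                       # [4, seq, 4096]
    t = x.reshape(b * s, d)                 # flatten tokens for the mms
    y = (F.silu(t @ w_gate.t()) * (t @ w_up.t())) @ w_down.t()
    return residual + y.reshape(b, s, d)    # w_down: [4096, 14336]
```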
torch.prims.convert_element_type %4362, %int5_5200 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4363, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_5201 = torch.constant.int -2 + %int-1_5202 = torch.constant.int -1 + %4364 = torch.aten.transpose.int %138, %int-2_5201, %int-1_5202 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5203 = torch.constant.int 5 + %4365 = torch.prims.convert_element_type %4364, %int5_5203 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_5204 = torch.constant.int 4096 + %4366 = torch.prim.ListConstruct %342, %int4096_5204 : (!torch.int, !torch.int) -> !torch.list + %4367 = torch.aten.view %4363, %4366 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4367, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4368 = torch.aten.mm %4367, %4365 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4368, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_5205 = torch.constant.int 4 + %int4096_5206 = torch.constant.int 4096 + %4369 = torch.prim.ListConstruct %int4_5205, %298, %int4096_5206 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4370 = torch.aten.view %4368, %4369 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4370, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_5207 = torch.constant.int -2 + %int-1_5208 = torch.constant.int -1 + %4371 = torch.aten.transpose.int %139, %int-2_5207, %int-1_5208 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5209 = torch.constant.int 5 + %4372 = torch.prims.convert_element_type %4371, %int5_5209 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_5210 = torch.constant.int 4096 + %4373 = torch.prim.ListConstruct %342, %int4096_5210 : (!torch.int, !torch.int) -> !torch.list + %4374 = torch.aten.view %4363, %4373 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4374, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4375 = torch.aten.mm %4374, %4372 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %4375, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_5211 = torch.constant.int 4 + %int1024_5212 = torch.constant.int 1024 + %4376 = torch.prim.ListConstruct %int4_5211, %298, %int1024_5212 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4377 = torch.aten.view %4375, %4376 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %4377, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_5213 = torch.constant.int -2 + %int-1_5214 = torch.constant.int -1 + %4378 = torch.aten.transpose.int %140, %int-2_5213, %int-1_5214 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5215 = torch.constant.int 5 + %4379 = torch.prims.convert_element_type %4378, %int5_5215 : !torch.vtensor<[4096,1024],f16>, !torch.int -> 
!torch.vtensor<[4096,1024],f16> + %int4096_5216 = torch.constant.int 4096 + %4380 = torch.prim.ListConstruct %342, %int4096_5216 : (!torch.int, !torch.int) -> !torch.list + %4381 = torch.aten.view %4363, %4380 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4381, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4382 = torch.aten.mm %4381, %4379 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %4382, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_5217 = torch.constant.int 4 + %int1024_5218 = torch.constant.int 1024 + %4383 = torch.prim.ListConstruct %int4_5217, %298, %int1024_5218 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4384 = torch.aten.view %4382, %4383 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %4384, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> %int4_5219 = torch.constant.int 4 - %4363 = torch.aten.mul.int %int4_5219, %398 : !torch.int, !torch.int -> !torch.int - %4364 = torch.prim.ListConstruct %4363 : (!torch.int) -> !torch.list - %4365 = torch.aten.view %4362, %4364 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %4365, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %4366 = torch.prim.ListConstruct %4365 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_5220 = torch.constant.bool false - %4367 = torch.aten.index_put %4356, %4366, %4361, %false_5220 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4367, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_5221 = torch.constant.int 32 - %int2_5222 = torch.constant.int 2 - %int32_5223 = torch.constant.int 32 - %int8_5224 = torch.constant.int 8 - %int128_5225 = torch.constant.int 128 - %4368 = torch.prim.ListConstruct %389, %int32_5221, %int2_5222, %int32_5223, %int8_5224, %int128_5225 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4369 = torch.aten.view %4367, %4368 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4369, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5226 = torch.constant.int 2097152 - %4370 = torch.prim.ListConstruct %389, %int2097152_5226 : (!torch.int, !torch.int) -> !torch.list - %4371 = torch.aten.view %4369, %4370 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4371, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_5227 = torch.constant.int -2 - %4372 = torch.aten.unsqueeze %4330, %int-2_5227 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4372, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_5228 = torch.constant.int 4 - %int8_5229 = torch.constant.int 8 - %int4_5230 = torch.constant.int 4 - %int128_5231 = torch.constant.int 128 - %4373 = torch.prim.ListConstruct %int4_5228, %4315, %int8_5229, %int4_5230, %int128_5231 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> 
!torch.list + %int32_5220 = torch.constant.int 32 + %int128_5221 = torch.constant.int 128 + %4385 = torch.prim.ListConstruct %int4_5219, %298, %int32_5220, %int128_5221 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4386 = torch.aten.view %4370, %4385 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4386, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_5222 = torch.constant.int 4 + %int8_5223 = torch.constant.int 8 + %int128_5224 = torch.constant.int 128 + %4387 = torch.prim.ListConstruct %int4_5222, %298, %int8_5223, %int128_5224 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4388 = torch.aten.view %4377, %4387 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4388, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_5225 = torch.constant.int 4 + %int8_5226 = torch.constant.int 8 + %int128_5227 = torch.constant.int 128 + %4389 = torch.prim.ListConstruct %int4_5225, %298, %int8_5226, %int128_5227 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4390 = torch.aten.view %4384, %4389 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4390, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_5228 = torch.constant.int 131072 + %none_5229 = torch.constant.none + %none_5230 = torch.constant.none + %cpu_5231 = torch.constant.device "cpu" %false_5232 = torch.constant.bool false - %4374 = torch.aten.expand %4372, %4373, %false_5232 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4374, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %4391 = torch.aten.arange %int131072_5228, %none_5229, %none_5230, %cpu_5231, %false_5232 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> %int0_5233 = torch.constant.int 0 - %4375 = torch.aten.clone %4374, %int0_5233 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4375, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_5234 = torch.constant.int 4 - %int32_5235 = torch.constant.int 32 - %int128_5236 = torch.constant.int 128 - %4376 = torch.prim.ListConstruct %int4_5234, %4315, %int32_5235, %int128_5236 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4377 = torch.aten._unsafe_view %4375, %4376 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4377, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_5237 = torch.constant.int -2 - %4378 = torch.aten.unsqueeze %4274, %int-2_5237 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4378, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_5238 = torch.constant.int 1 - %4379 = torch.aten.size.int %4268, %int1_5238 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_5239 = torch.constant.int 4 - %int8_5240 = torch.constant.int 8 - %int4_5241 = torch.constant.int 4 - %int128_5242 = torch.constant.int 128 - %4380 = 
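%4364-%4390 compute the next block's Q/K/V projections and split them into heads: Q through the [4096,4096] weight %138 into [4,?,32,128], K and V through the [1024,4096] weights %139/%140 into [4,?,8,128] each (1024 = 8 KV heads x 128, the grouped-query layout of Llama 3 8B). Roughly:

```python
import torch

def qkv_proj(x, w_q, w_k, w_v):
    b, s, d = x.shape                      # [4, seq, 4096]
    t = x.reshape(b * s, d)
    q = (t @ w_q.t()).view(b, s, 32, 128)  # w_q: [4096, 4096]
    k = (t @ w_k.t()).view(b, s, 8, 128)   # w_k: [1024, 4096]
    v = (t @ w_v.t()).view(b, s, 8, 128)   # w_v: [1024, 4096]
    return q, k, v
```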
torch.prim.ListConstruct %int4_5239, %4379, %int8_5240, %int4_5241, %int128_5242 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_5243 = torch.constant.bool false - %4381 = torch.aten.expand %4378, %4380, %false_5243 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4381, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_5244 = torch.constant.int 0 - %4382 = torch.aten.clone %4381, %int0_5244 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4382, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_5245 = torch.constant.int 4 - %int32_5246 = torch.constant.int 32 - %int128_5247 = torch.constant.int 128 - %4383 = torch.prim.ListConstruct %int4_5245, %4379, %int32_5246, %int128_5247 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4384 = torch.aten._unsafe_view %4382, %4383 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4384, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int128_5234 = torch.constant.int 128 + %int2_5235 = torch.constant.int 2 + %int4_5236 = torch.constant.int 4 + %none_5237 = torch.constant.none + %cpu_5238 = torch.constant.device "cpu" + %false_5239 = torch.constant.bool false + %4392 = torch.aten.arange.start_step %int0_5233, %int128_5234, %int2_5235, %int4_5236, %none_5237, %cpu_5238, %false_5239 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_5240 = torch.constant.int 6 + %4393 = torch.prims.convert_element_type %4392, %int6_5240 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_5241 = torch.constant.int 128 + %4394 = torch.aten.div.Scalar %4393, %int128_5241 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_5242 = torch.constant.float 5.000000e+05 + %4395 = torch.aten.pow.Scalar %float5.000000e05_5242, %4394 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4396 = torch.aten.reciprocal %4395 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_5243 = torch.constant.float 1.000000e+00 + %4397 = torch.aten.mul.Scalar %4396, %float1.000000e00_5243 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %4398 = torch.aten.reciprocal %4397 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_5244 = torch.constant.float 6.2831853071795862 + %4399 = torch.aten.mul.Scalar %4398, %float6.283190e00_5244 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_5245 = torch.constant.float 8.192000e+03 + %4400 = torch.aten.gt.Scalar %4399, %float8.192000e03_5245 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_5246 = torch.constant.int 8 + %4401 = torch.aten.div.Scalar %4397, %int8_5246 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %4402 = torch.aten.where.self %4400, %4401, %4397 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4403 = torch.aten.reciprocal %4399 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_5247 = torch.constant.int 8192 + %4404 = torch.aten.mul.Scalar %4403, %int8192_5247 : 
!torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_5248 = torch.constant.int 1 - %int2_5249 = torch.constant.int 2 - %4385 = torch.aten.transpose.int %4302, %int1_5248, %int2_5249 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4385, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_5250 = torch.constant.int 1 - %int2_5251 = torch.constant.int 2 - %4386 = torch.aten.transpose.int %4377, %int1_5250, %int2_5251 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4386, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_5249 = torch.constant.int 1 + %4405 = torch.aten.sub.Scalar %4404, %int1_5248, %int1_5249 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_5250 = torch.constant.int 3 + %4406 = torch.aten.div.Scalar %4405, %int3_5250 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_5251 = torch.constant.int 1 %int1_5252 = torch.constant.int 1 - %int2_5253 = torch.constant.int 2 - %4387 = torch.aten.transpose.int %4384, %int1_5252, %int2_5253 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4387, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_5254 = torch.constant.float 0.000000e+00 - %true_5255 = torch.constant.bool true - %none_5256 = torch.constant.none - %none_5257 = torch.constant.none - %4388:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4385, %4386, %4387, %float0.000000e00_5254, %true_5255, %none_5256, %none_5257) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %4388#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_5258 = torch.constant.int 1 - %int2_5259 = torch.constant.int 2 - %4389 = torch.aten.transpose.int %4388#0, %int1_5258, %int2_5259 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4389, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_5260 = torch.constant.int 4 - %int4096_5261 = torch.constant.int 4096 - %4390 = torch.prim.ListConstruct %int4_5260, %4287, %int4096_5261 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4391 = torch.aten.view %4389, %4390 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4391, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_5262 = torch.constant.int -2 - %int-1_5263 = torch.constant.int -1 - %4392 = torch.aten.transpose.int %185, %int-2_5262, %int-1_5263 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_5264 = torch.constant.int 4 - %4393 = torch.aten.mul.int %int4_5264, %4287 : !torch.int, !torch.int -> !torch.int - %int4096_5265 = torch.constant.int 4096 - %4394 = torch.prim.ListConstruct %4393, %int4096_5265 : (!torch.int, !torch.int) -> !torch.list - %4395 = torch.aten.view %4391, %4394 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> 
!torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4395, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4396 = torch.aten.mm %4395, %4392 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4396, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_5266 = torch.constant.int 4 - %int4096_5267 = torch.constant.int 4096 - %4397 = torch.prim.ListConstruct %int4_5266, %4287, %int4096_5267 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4398 = torch.aten.view %4396, %4397 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4398, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %4407 = torch.aten.rsub.Scalar %4406, %int1_5251, %int1_5252 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %4408 = torch.aten.mul.Tensor %4407, %4402 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_5253 = torch.constant.int 8 + %4409 = torch.aten.div.Scalar %4408, %int8_5253 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %4410 = torch.aten.mul.Tensor %4406, %4402 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_5254 = torch.constant.int 1 + %4411 = torch.aten.add.Tensor %4409, %4410, %int1_5254 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_5255 = torch.constant.float 2.048000e+03 + %4412 = torch.aten.lt.Scalar %4399, %float2.048000e03_5255 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %4413 = torch.aten.bitwise_not %4412 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_5256 = torch.constant.float 8.192000e+03 + %4414 = torch.aten.gt.Scalar %4399, %float8.192000e03_5256 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %4415 = torch.aten.bitwise_not %4414 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %4416 = torch.aten.mul.Tensor %4413, %4415 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %4417 = torch.aten.where.self %4416, %4411, %4402 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4418 = torch.prim.ListConstruct %4417, %4417 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_5257 = torch.constant.int -1 + %4419 = torch.aten.cat %4418, %int-1_5257 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_5258 = torch.constant.int 6 + %4420 = torch.prims.convert_element_type %4419, %int6_5258 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_5259 = torch.constant.int 1 + %4421 = torch.aten.unsqueeze %4391, %int1_5259 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_5260 = torch.constant.int 6 + %4422 = torch.prims.convert_element_type %4421, %int6_5260 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_5261 = torch.constant.int 0 + %4423 = torch.aten.unsqueeze %4420, %int0_5261 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_5262 = torch.constant.int 6 + %4424 = torch.prims.convert_element_type %4423, %int6_5262 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %4425 = torch.aten.mul.Tensor %4422, %4424 : 
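The added %4392-%4419 replace the old floor_divide-based frequency table with what appears to be Llama 3.1-style scaled RoPE: base-500000 inverse frequencies over half the 128-dim head, wavelengths above 8192 divided by the scale factor 8, and wavelengths between 2048 and 8192 blended via smooth = (8192/wavelen - 1)/3. A compact restatement:

```python
import math
import torch

def scaled_inv_freq(dim=128, base=5.0e5, factor=8.0, old_ctx=8192.0):
    inv_freq = 1.0 / base ** (torch.arange(0, dim, 2).float() / dim)  # [64]
    wavelen = 2 * math.pi / inv_freq
    inv_freq = torch.where(wavelen > old_ctx, inv_freq / factor, inv_freq)
    smooth = (old_ctx / wavelen - 1.0) / 3.0   # low/high freq factors 1 and 4
    blended = (1.0 - smooth) * inv_freq / factor + smooth * inv_freq
    medium = (wavelen >= old_ctx / 4.0) & (wavelen <= old_ctx)
    return torch.where(medium, blended, inv_freq)
```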
!torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %4426 = torch.aten.cos %4425 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_5263 = torch.constant.int 5 + %4427 = torch.prims.convert_element_type %4426, %int5_5263 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %4428 = torch.aten.sin %4425 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_5264 = torch.constant.int 5 + %4429 = torch.prims.convert_element_type %4428, %int5_5264 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_5265 = torch.constant.int 0 + %int0_5266 = torch.constant.int 0 + %int1_5267 = torch.constant.int 1 + %4430 = torch.aten.slice.Tensor %4427, %int0_5265, %int0_5266, %298, %int1_5267 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4430, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_5268 = torch.constant.int 1 - %4399 = torch.aten.add.Tensor %4237, %4398, %int1_5268 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4399, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_5269 = torch.constant.int 6 - %4400 = torch.prims.convert_element_type %4399, %int6_5269 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4400, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_5270 = torch.constant.int 2 - %4401 = torch.aten.pow.Tensor_Scalar %4400, %int2_5270 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4401, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_5271 = torch.constant.int -1 - %4402 = torch.prim.ListConstruct %int-1_5271 : (!torch.int) -> !torch.list - %true_5272 = torch.constant.bool true - %none_5273 = torch.constant.none - %4403 = torch.aten.mean.dim %4401, %4402, %true_5272, %none_5273 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4403, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_5274 = torch.constant.float 9.9999997473787516E-6 + %int0_5269 = torch.constant.int 0 + %int9223372036854775807_5270 = torch.constant.int 9223372036854775807 + %int1_5271 = torch.constant.int 1 + %4431 = torch.aten.slice.Tensor %4430, %int1_5268, %int0_5269, %int9223372036854775807_5270, %int1_5271 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4431, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_5272 = torch.constant.int 0 + %int0_5273 = torch.constant.int 0 + %int1_5274 = torch.constant.int 1 + %4432 = torch.aten.slice.Tensor %4429, %int0_5272, %int0_5273, %298, %int1_5274 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4432, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_5275 = torch.constant.int 1 - %4404 = torch.aten.add.Scalar %4403, %float9.999990e-06_5274, %int1_5275 : !torch.vtensor<[4,?,1],f32>, !torch.float, 
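%4421-%4429 then build the full cos/sin tables: the 64 scaled frequencies are duplicated to 128, multiplied against positions up to the 131072-token max context, and the cos/sin results are stored as fp16 before being sliced to the live sequence length and broadcast to [4,?,1,128] further down. A sketch (using an unscaled inv_freq for brevity; the IR applies the scaling shown above):

```python
import torch

inv_freq = 1.0 / 5.0e5 ** (torch.arange(0, 128, 2).float() / 128)  # [64]
freqs = torch.cat((inv_freq, inv_freq), dim=-1)                    # [128]
angles = torch.arange(131072).float()[:, None] * freqs[None, :]
cos, sin = angles.cos().half(), angles.sin().half()  # [131072, 128] fp16
seq_len = 32                                         # illustrative value
cos_b = cos[:seq_len][None, :, None, :].repeat(4, 1, 1, 1)  # [4, seq, 1, 128]
```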
!torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4404, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4405 = torch.aten.rsqrt %4404 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4405, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4406 = torch.aten.mul.Tensor %4400, %4405 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4406, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_5276 = torch.constant.int 5 - %4407 = torch.prims.convert_element_type %4406, %int5_5276 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4407, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %4408 = torch.aten.mul.Tensor %186, %4407 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4408, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_5277 = torch.constant.int 5 - %4409 = torch.prims.convert_element_type %4408, %int5_5277 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4409, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_5278 = torch.constant.int -2 - %int-1_5279 = torch.constant.int -1 - %4410 = torch.aten.transpose.int %187, %int-2_5278, %int-1_5279 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_5280 = torch.constant.int 4 - %4411 = torch.aten.mul.int %int4_5280, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5281 = torch.constant.int 4096 - %4412 = torch.prim.ListConstruct %4411, %int4096_5281 : (!torch.int, !torch.int) -> !torch.list - %4413 = torch.aten.view %4409, %4412 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4413, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4414 = torch.aten.mm %4413, %4410 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %4414, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_5282 = torch.constant.int 4 - %int14336_5283 = torch.constant.int 14336 - %4415 = torch.prim.ListConstruct %int4_5282, %306, %int14336_5283 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4416 = torch.aten.view %4414, %4415 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %4416, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %4417 = torch.aten.silu %4416 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %4417, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_5284 = torch.constant.int -2 - %int-1_5285 = torch.constant.int -1 - %4418 = torch.aten.transpose.int %188, %int-2_5284, %int-1_5285 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_5286 = torch.constant.int 4 - %4419 = torch.aten.mul.int %int4_5286, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5287 = torch.constant.int 4096 - %4420 = torch.prim.ListConstruct %4419, %int4096_5287 : 
(!torch.int, !torch.int) -> !torch.list - %4421 = torch.aten.view %4409, %4420 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4421, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4422 = torch.aten.mm %4421, %4418 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %4422, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_5288 = torch.constant.int 4 - %int14336_5289 = torch.constant.int 14336 - %4423 = torch.prim.ListConstruct %int4_5288, %306, %int14336_5289 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4424 = torch.aten.view %4422, %4423 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %4424, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %4425 = torch.aten.mul.Tensor %4417, %4424 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %4425, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_5290 = torch.constant.int -2 - %int-1_5291 = torch.constant.int -1 - %4426 = torch.aten.transpose.int %189, %int-2_5290, %int-1_5291 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int0_5276 = torch.constant.int 0 + %int9223372036854775807_5277 = torch.constant.int 9223372036854775807 + %int1_5278 = torch.constant.int 1 + %4433 = torch.aten.slice.Tensor %4432, %int1_5275, %int0_5276, %int9223372036854775807_5277, %int1_5278 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4433, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_5279 = torch.constant.int 0 + %4434 = torch.aten.unsqueeze %4431, %int0_5279 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4434, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_5280 = torch.constant.int 1 + %int0_5281 = torch.constant.int 0 + %int9223372036854775807_5282 = torch.constant.int 9223372036854775807 + %int1_5283 = torch.constant.int 1 + %4435 = torch.aten.slice.Tensor %4434, %int1_5280, %int0_5281, %int9223372036854775807_5282, %int1_5283 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4435, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_5284 = torch.constant.int 2 + %4436 = torch.aten.unsqueeze %4435, %int2_5284 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4436, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_5285 = torch.constant.int 3 + %int0_5286 = torch.constant.int 0 + %int9223372036854775807_5287 = torch.constant.int 9223372036854775807 + %int1_5288 = torch.constant.int 1 + %4437 = torch.aten.slice.Tensor %4436, %int3_5285, %int0_5286, %int9223372036854775807_5287, %int1_5288 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4437, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_5289 = 
torch.constant.int 4 + %int1_5290 = torch.constant.int 1 + %int1_5291 = torch.constant.int 1 %int1_5292 = torch.constant.int 1 - %4427 = torch.aten.size.int %4416, %int1_5292 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_5293 = torch.constant.int 4 - %4428 = torch.aten.mul.int %int4_5293, %4427 : !torch.int, !torch.int -> !torch.int - %int14336_5294 = torch.constant.int 14336 - %4429 = torch.prim.ListConstruct %4428, %int14336_5294 : (!torch.int, !torch.int) -> !torch.list - %4430 = torch.aten.view %4425, %4429 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %4430, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %4431 = torch.aten.mm %4430, %4426 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4431, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_5295 = torch.constant.int 4 - %int4096_5296 = torch.constant.int 4096 - %4432 = torch.prim.ListConstruct %int4_5295, %4427, %int4096_5296 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4433 = torch.aten.view %4431, %4432 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4433, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %4438 = torch.prim.ListConstruct %int4_5289, %int1_5290, %int1_5291, %int1_5292 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4439 = torch.aten.repeat %4437, %4438 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %4439, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_5293 = torch.constant.int 0 + %4440 = torch.aten.unsqueeze %4433, %int0_5293 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4440, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_5294 = torch.constant.int 1 + %int0_5295 = torch.constant.int 0 + %int9223372036854775807_5296 = torch.constant.int 9223372036854775807 %int1_5297 = torch.constant.int 1 - %4434 = torch.aten.add.Tensor %4399, %4433, %int1_5297 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4434, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_5298 = torch.constant.int 6 - %4435 = torch.prims.convert_element_type %4434, %int6_5298 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4435, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_5299 = torch.constant.int 2 - %4436 = torch.aten.pow.Tensor_Scalar %4435, %int2_5299 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4436, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_5300 = torch.constant.int -1 - %4437 = torch.prim.ListConstruct %int-1_5300 : (!torch.int) -> !torch.list - %true_5301 = torch.constant.bool true - %none_5302 = torch.constant.none - %4438 = torch.aten.mean.dim %4436, %4437, %true_5301, %none_5302 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4438, [%292], affine_map<()[s0] -> (4, 
s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_5303 = torch.constant.float 9.9999997473787516E-6 + %4441 = torch.aten.slice.Tensor %4440, %int1_5294, %int0_5295, %int9223372036854775807_5296, %int1_5297 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4441, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_5298 = torch.constant.int 2 + %4442 = torch.aten.unsqueeze %4441, %int2_5298 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4442, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_5299 = torch.constant.int 3 + %int0_5300 = torch.constant.int 0 + %int9223372036854775807_5301 = torch.constant.int 9223372036854775807 + %int1_5302 = torch.constant.int 1 + %4443 = torch.aten.slice.Tensor %4442, %int3_5299, %int0_5300, %int9223372036854775807_5301, %int1_5302 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4443, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_5303 = torch.constant.int 4 %int1_5304 = torch.constant.int 1 - %4439 = torch.aten.add.Scalar %4438, %float9.999990e-06_5303, %int1_5304 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4439, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4440 = torch.aten.rsqrt %4439 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4440, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4441 = torch.aten.mul.Tensor %4435, %4440 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4441, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_5305 = torch.constant.int 5 - %4442 = torch.prims.convert_element_type %4441, %int5_5305 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4442, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %4443 = torch.aten.mul.Tensor %190, %4442 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4443, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_5306 = torch.constant.int 5 - %4444 = torch.prims.convert_element_type %4443, %int5_5306 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4444, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_5307 = torch.constant.int -2 - %int-1_5308 = torch.constant.int -1 - %4445 = torch.aten.transpose.int %191, %int-2_5307, %int-1_5308 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_5309 = torch.constant.int 4 - %4446 = torch.aten.mul.int %int4_5309, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5310 = torch.constant.int 4096 - %4447 = torch.prim.ListConstruct %4446, %int4096_5310 : (!torch.int, !torch.int) -> !torch.list - %4448 = torch.aten.view %4444, %4447 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4448, 
[%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4449 = torch.aten.mm %4448, %4445 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4449, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_5311 = torch.constant.int 4 - %int4096_5312 = torch.constant.int 4096 - %4450 = torch.prim.ListConstruct %int4_5311, %306, %int4096_5312 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4451 = torch.aten.view %4449, %4450 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4451, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_5313 = torch.constant.int -2 - %int-1_5314 = torch.constant.int -1 - %4452 = torch.aten.transpose.int %192, %int-2_5313, %int-1_5314 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_5315 = torch.constant.int 4 - %4453 = torch.aten.mul.int %int4_5315, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5316 = torch.constant.int 4096 - %4454 = torch.prim.ListConstruct %4453, %int4096_5316 : (!torch.int, !torch.int) -> !torch.list - %4455 = torch.aten.view %4444, %4454 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4455, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4456 = torch.aten.mm %4455, %4452 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %4456, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_5317 = torch.constant.int 4 - %int1024_5318 = torch.constant.int 1024 - %4457 = torch.prim.ListConstruct %int4_5317, %306, %int1024_5318 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4458 = torch.aten.view %4456, %4457 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %4458, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_5319 = torch.constant.int -2 - %int-1_5320 = torch.constant.int -1 - %4459 = torch.aten.transpose.int %193, %int-2_5319, %int-1_5320 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_5321 = torch.constant.int 4 - %4460 = torch.aten.mul.int %int4_5321, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5322 = torch.constant.int 4096 - %4461 = torch.prim.ListConstruct %4460, %int4096_5322 : (!torch.int, !torch.int) -> !torch.list - %4462 = torch.aten.view %4444, %4461 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4462, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4463 = torch.aten.mm %4462, %4459 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %4463, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_5323 = torch.constant.int 4 - %int1024_5324 = torch.constant.int 1024 - %4464 = torch.prim.ListConstruct %int4_5323, %306, %int1024_5324 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4465 = torch.aten.view %4463, %4464 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %4465, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : 
!torch.vtensor<[4,?,1024],f16> + %int1_5305 = torch.constant.int 1 + %int1_5306 = torch.constant.int 1 + %4444 = torch.prim.ListConstruct %int4_5303, %int1_5304, %int1_5305, %int1_5306 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4445 = torch.aten.repeat %4443, %4444 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %4445, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %4446 = torch.aten.mul.Tensor %4386, %4439 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4446, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_5307 = torch.constant.int 3 + %int0_5308 = torch.constant.int 0 + %int64_5309 = torch.constant.int 64 + %int1_5310 = torch.constant.int 1 + %4447 = torch.aten.slice.Tensor %4386, %int3_5307, %int0_5308, %int64_5309, %int1_5310 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %4447, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_5311 = torch.constant.int 3 + %int64_5312 = torch.constant.int 64 + %int9223372036854775807_5313 = torch.constant.int 9223372036854775807 + %int1_5314 = torch.constant.int 1 + %4448 = torch.aten.slice.Tensor %4386, %int3_5311, %int64_5312, %int9223372036854775807_5313, %int1_5314 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %4448, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %4449 = torch.aten.neg %4448 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %4449, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %4450 = torch.prim.ListConstruct %4449, %4447 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_5315 = torch.constant.int -1 + %4451 = torch.aten.cat %4450, %int-1_5315 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4451, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %4452 = torch.aten.mul.Tensor %4451, %4445 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4452, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_5316 = torch.constant.int 1 + %4453 = torch.aten.add.Tensor %4446, %4452, %int1_5316 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4453, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_5317 = torch.constant.int 131072 + %none_5318 = torch.constant.none + %none_5319 = torch.constant.none + %cpu_5320 = torch.constant.device "cpu" + %false_5321 = torch.constant.bool false + %4454 = torch.aten.arange %int131072_5317, %none_5318, %none_5319, %cpu_5320, %false_5321 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_5322 = torch.constant.int 0 + %int128_5323 = torch.constant.int 128 + %int2_5324 = torch.constant.int 2 %int4_5325 = torch.constant.int 4 - %int32_5326 = 
torch.constant.int 32 - %int128_5327 = torch.constant.int 128 - %4466 = torch.prim.ListConstruct %int4_5325, %306, %int32_5326, %int128_5327 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4467 = torch.aten.view %4451, %4466 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4467, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_5328 = torch.constant.int 4 - %int8_5329 = torch.constant.int 8 + %none_5326 = torch.constant.none + %cpu_5327 = torch.constant.device "cpu" + %false_5328 = torch.constant.bool false + %4455 = torch.aten.arange.start_step %int0_5322, %int128_5323, %int2_5324, %int4_5325, %none_5326, %cpu_5327, %false_5328 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_5329 = torch.constant.int 6 + %4456 = torch.prims.convert_element_type %4455, %int6_5329 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> %int128_5330 = torch.constant.int 128 - %4468 = torch.prim.ListConstruct %int4_5328, %306, %int8_5329, %int128_5330 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4469 = torch.aten.view %4458, %4468 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4469, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_5331 = torch.constant.int 4 - %int8_5332 = torch.constant.int 8 - %int128_5333 = torch.constant.int 128 - %4470 = torch.prim.ListConstruct %int4_5331, %306, %int8_5332, %int128_5333 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4471 = torch.aten.view %4465, %4470 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4471, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_5334 = torch.constant.int 131072 - %none_5335 = torch.constant.none - %none_5336 = torch.constant.none - %cpu_5337 = torch.constant.device "cpu" - %false_5338 = torch.constant.bool false - %4472 = torch.aten.arange %int131072_5334, %none_5335, %none_5336, %cpu_5337, %false_5338 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_5339 = torch.constant.int 0 - %int128_5340 = torch.constant.int 128 - %none_5341 = torch.constant.none - %none_5342 = torch.constant.none - %cpu_5343 = torch.constant.device "cpu" - %false_5344 = torch.constant.bool false - %4473 = torch.aten.arange.start %int0_5339, %int128_5340, %none_5341, %none_5342, %cpu_5343, %false_5344 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_5345 = torch.constant.int 2 - %4474 = torch.aten.floor_divide.Scalar %4473, %int2_5345 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_5346 = torch.constant.int 6 - %4475 = torch.prims.convert_element_type %4474, %int6_5346 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_5347 = torch.constant.int 128 - %4476 = torch.aten.div.Scalar %4475, %int128_5347 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_5348 = torch.constant.float 2.000000e+00 - %4477 = torch.aten.mul.Scalar %4476, %float2.000000e00_5348 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_5349 = torch.constant.float 5.000000e+05 
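// NOTE (reviewer annotation, not part of the generated IR): the surrounding hunk
// replaces the old position-frequency computation (a single 128-element arange with
// floor_divide-by-2, fed to the fused util.call rotary kernels) with an inline
// rotate-half formulation: a 64-element half-spectrum of inverse frequencies is
// built, Llama 3.1-style frequency scaling is applied, and cos/sin tables are
// materialized directly in f16. A minimal sketch of what the new ops appear to
// compute, in PyTorch-like pseudocode (names are illustrative, not from the IR):
//
//   inv_freq = 1.0 / (500000.0 ** (torch.arange(0, 128, 2).float() / 128))
//   wavelen  = 2 * math.pi / inv_freq
//   scaled   = torch.where(wavelen > 8192.0, inv_freq / 8.0, inv_freq)  # low-freq band: divide by factor 8
//   smooth   = (8192.0 / wavelen - 1.0) / 3.0                           # blend factor for the medium band
//   blended  = (1.0 - smooth) * scaled / 8.0 + smooth * scaled
//   medium   = (wavelen >= 2048.0) & (wavelen <= 8192.0)                # built via two bitwise_not + mul in the IR
//   inv_freq = torch.where(medium, blended, scaled)
//   freqs    = torch.cat([inv_freq, inv_freq], dim=-1)                  # duplicate back to 128 dims
//
// The angle table is then positions[:, None] * freqs[None, :] over the 131072
// max positions, with cos/sin taken in f32 and truncated to f16, matching the
// convert_element_type ops that follow.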
- %4478 = torch.aten.pow.Scalar %float5.000000e05_5349, %4477 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %4479 = torch.aten.reciprocal %4478 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_5350 = torch.constant.float 1.000000e+00 - %4480 = torch.aten.mul.Scalar %4479, %float1.000000e00_5350 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_5351 = torch.constant.int 1 - %4481 = torch.aten.unsqueeze %4472, %int1_5351 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_5352 = torch.constant.int 0 - %4482 = torch.aten.unsqueeze %4480, %int0_5352 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %4483 = torch.aten.mul.Tensor %4481, %4482 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_5353 = torch.constant.int 1 - %4484 = torch.aten.size.int %4451, %int1_5353 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int + %4457 = torch.aten.div.Scalar %4456, %int128_5330 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_5331 = torch.constant.float 5.000000e+05 + %4458 = torch.aten.pow.Scalar %float5.000000e05_5331, %4457 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4459 = torch.aten.reciprocal %4458 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_5332 = torch.constant.float 1.000000e+00 + %4460 = torch.aten.mul.Scalar %4459, %float1.000000e00_5332 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %4461 = torch.aten.reciprocal %4460 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_5333 = torch.constant.float 6.2831853071795862 + %4462 = torch.aten.mul.Scalar %4461, %float6.283190e00_5333 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_5334 = torch.constant.float 8.192000e+03 + %4463 = torch.aten.gt.Scalar %4462, %float8.192000e03_5334 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_5335 = torch.constant.int 8 + %4464 = torch.aten.div.Scalar %4460, %int8_5335 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %4465 = torch.aten.where.self %4463, %4464, %4460 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4466 = torch.aten.reciprocal %4462 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_5336 = torch.constant.int 8192 + %4467 = torch.aten.mul.Scalar %4466, %int8192_5336 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_5337 = torch.constant.int 1 + %int1_5338 = torch.constant.int 1 + %4468 = torch.aten.sub.Scalar %4467, %int1_5337, %int1_5338 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_5339 = torch.constant.int 3 + %4469 = torch.aten.div.Scalar %4468, %int3_5339 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_5340 = torch.constant.int 1 + %int1_5341 = torch.constant.int 1 + %4470 = torch.aten.rsub.Scalar %4469, %int1_5340, %int1_5341 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %4471 = torch.aten.mul.Tensor %4470, %4465 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_5342 = torch.constant.int 8 + %4472 = torch.aten.div.Scalar %4471, %int8_5342 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %4473 = 
torch.aten.mul.Tensor %4469, %4465 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_5343 = torch.constant.int 1 + %4474 = torch.aten.add.Tensor %4472, %4473, %int1_5343 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_5344 = torch.constant.float 2.048000e+03 + %4475 = torch.aten.lt.Scalar %4462, %float2.048000e03_5344 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %4476 = torch.aten.bitwise_not %4475 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_5345 = torch.constant.float 8.192000e+03 + %4477 = torch.aten.gt.Scalar %4462, %float8.192000e03_5345 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %4478 = torch.aten.bitwise_not %4477 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %4479 = torch.aten.mul.Tensor %4476, %4478 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %4480 = torch.aten.where.self %4479, %4474, %4465 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4481 = torch.prim.ListConstruct %4480, %4480 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_5346 = torch.constant.int -1 + %4482 = torch.aten.cat %4481, %int-1_5346 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_5347 = torch.constant.int 6 + %4483 = torch.prims.convert_element_type %4482, %int6_5347 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_5348 = torch.constant.int 1 + %4484 = torch.aten.unsqueeze %4454, %int1_5348 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_5349 = torch.constant.int 6 + %4485 = torch.prims.convert_element_type %4484, %int6_5349 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_5350 = torch.constant.int 0 + %4486 = torch.aten.unsqueeze %4483, %int0_5350 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_5351 = torch.constant.int 6 + %4487 = torch.prims.convert_element_type %4486, %int6_5351 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %4488 = torch.aten.mul.Tensor %4485, %4487 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %4489 = torch.aten.cos %4488 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_5352 = torch.constant.int 5 + %4490 = torch.prims.convert_element_type %4489, %int5_5352 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %4491 = torch.aten.sin %4488 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_5353 = torch.constant.int 5 + %4492 = torch.prims.convert_element_type %4491, %int5_5353 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> %int0_5354 = torch.constant.int 0 - %4485 = torch.aten.add.int %int0_5354, %4484 : !torch.int, !torch.int -> !torch.int %int0_5355 = torch.constant.int 0 - %int0_5356 = torch.constant.int 0 + %int1_5356 = torch.constant.int 1 + %4493 = torch.aten.slice.Tensor %4490, %int0_5354, %int0_5355, %298, %int1_5356 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4493, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_5357 = torch.constant.int 1 - %4486 = torch.aten.slice.Tensor %4483, 
%int0_5355, %int0_5356, %4485, %int1_5357 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4486, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_5358 = torch.constant.int 1 - %int0_5359 = torch.constant.int 0 - %int9223372036854775807_5360 = torch.constant.int 9223372036854775807 - %int1_5361 = torch.constant.int 1 - %4487 = torch.aten.slice.Tensor %4486, %int1_5358, %int0_5359, %int9223372036854775807_5360, %int1_5361 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4487, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_5362 = torch.constant.int 1 - %int0_5363 = torch.constant.int 0 - %int9223372036854775807_5364 = torch.constant.int 9223372036854775807 - %int1_5365 = torch.constant.int 1 - %4488 = torch.aten.slice.Tensor %4487, %int1_5362, %int0_5363, %int9223372036854775807_5364, %int1_5365 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4488, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_5366 = torch.constant.int 0 - %4489 = torch.aten.unsqueeze %4488, %int0_5366 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4489, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> + %int0_5358 = torch.constant.int 0 + %int9223372036854775807_5359 = torch.constant.int 9223372036854775807 + %int1_5360 = torch.constant.int 1 + %4494 = torch.aten.slice.Tensor %4493, %int1_5357, %int0_5358, %int9223372036854775807_5359, %int1_5360 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4494, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_5361 = torch.constant.int 0 + %int0_5362 = torch.constant.int 0 + %int1_5363 = torch.constant.int 1 + %4495 = torch.aten.slice.Tensor %4492, %int0_5361, %int0_5362, %298, %int1_5363 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4495, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_5364 = torch.constant.int 1 + %int0_5365 = torch.constant.int 0 + %int9223372036854775807_5366 = torch.constant.int 9223372036854775807 %int1_5367 = torch.constant.int 1 + %4496 = torch.aten.slice.Tensor %4495, %int1_5364, %int0_5365, %int9223372036854775807_5366, %int1_5367 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4496, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int0_5368 = torch.constant.int 0 - %int9223372036854775807_5369 = torch.constant.int 9223372036854775807 - %int1_5370 = torch.constant.int 1 - %4490 = torch.aten.slice.Tensor %4489, %int1_5367, %int0_5368, %int9223372036854775807_5369, %int1_5370 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4490, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_5371 = torch.constant.int 2 - %int0_5372 = torch.constant.int 0 - %int9223372036854775807_5373 = torch.constant.int 
9223372036854775807 - %int1_5374 = torch.constant.int 1 - %4491 = torch.aten.slice.Tensor %4490, %int2_5371, %int0_5372, %int9223372036854775807_5373, %int1_5374 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4491, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_5375 = torch.constant.int 4 - %int1_5376 = torch.constant.int 1 + %4497 = torch.aten.unsqueeze %4494, %int0_5368 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4497, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_5369 = torch.constant.int 1 + %int0_5370 = torch.constant.int 0 + %int9223372036854775807_5371 = torch.constant.int 9223372036854775807 + %int1_5372 = torch.constant.int 1 + %4498 = torch.aten.slice.Tensor %4497, %int1_5369, %int0_5370, %int9223372036854775807_5371, %int1_5372 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4498, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_5373 = torch.constant.int 2 + %4499 = torch.aten.unsqueeze %4498, %int2_5373 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4499, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_5374 = torch.constant.int 3 + %int0_5375 = torch.constant.int 0 + %int9223372036854775807_5376 = torch.constant.int 9223372036854775807 %int1_5377 = torch.constant.int 1 - %4492 = torch.prim.ListConstruct %int4_5375, %int1_5376, %int1_5377 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4493 = torch.aten.repeat %4491, %4492 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %4493, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_5378 = torch.constant.int 6 - %4494 = torch.prims.convert_element_type %4467, %int6_5378 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %4494, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %4495 = torch_c.to_builtin_tensor %4494 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %4496 = torch_c.to_builtin_tensor %4493 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %4497 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%4495, %4496) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %4498 = torch_c.from_builtin_tensor %4497 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %4498, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_5379 = torch.constant.int 5 - %4499 = torch.prims.convert_element_type %4498, %int5_5379 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4499, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_5380 = torch.constant.int 131072 - %none_5381 = torch.constant.none - %none_5382 = torch.constant.none - %cpu_5383 = torch.constant.device "cpu" - %false_5384 = torch.constant.bool false - %4500 = torch.aten.arange %int131072_5380, %none_5381, %none_5382, %cpu_5383, %false_5384 : !torch.int, !torch.none, !torch.none, 
!torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_5385 = torch.constant.int 0 - %int128_5386 = torch.constant.int 128 - %none_5387 = torch.constant.none - %none_5388 = torch.constant.none - %cpu_5389 = torch.constant.device "cpu" - %false_5390 = torch.constant.bool false - %4501 = torch.aten.arange.start %int0_5385, %int128_5386, %none_5387, %none_5388, %cpu_5389, %false_5390 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_5391 = torch.constant.int 2 - %4502 = torch.aten.floor_divide.Scalar %4501, %int2_5391 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_5392 = torch.constant.int 6 - %4503 = torch.prims.convert_element_type %4502, %int6_5392 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_5393 = torch.constant.int 128 - %4504 = torch.aten.div.Scalar %4503, %int128_5393 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_5394 = torch.constant.float 2.000000e+00 - %4505 = torch.aten.mul.Scalar %4504, %float2.000000e00_5394 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_5395 = torch.constant.float 5.000000e+05 - %4506 = torch.aten.pow.Scalar %float5.000000e05_5395, %4505 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %4507 = torch.aten.reciprocal %4506 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_5396 = torch.constant.float 1.000000e+00 - %4508 = torch.aten.mul.Scalar %4507, %float1.000000e00_5396 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_5397 = torch.constant.int 1 - %4509 = torch.aten.unsqueeze %4500, %int1_5397 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_5398 = torch.constant.int 0 - %4510 = torch.aten.unsqueeze %4508, %int0_5398 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %4511 = torch.aten.mul.Tensor %4509, %4510 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %4500 = torch.aten.slice.Tensor %4499, %int3_5374, %int0_5375, %int9223372036854775807_5376, %int1_5377 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4500, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_5378 = torch.constant.int 4 + %int1_5379 = torch.constant.int 1 + %int1_5380 = torch.constant.int 1 + %int1_5381 = torch.constant.int 1 + %4501 = torch.prim.ListConstruct %int4_5378, %int1_5379, %int1_5380, %int1_5381 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4502 = torch.aten.repeat %4500, %4501 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %4502, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_5382 = torch.constant.int 0 + %4503 = torch.aten.unsqueeze %4496, %int0_5382 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4503, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_5383 = torch.constant.int 1 + %int0_5384 = torch.constant.int 0 + %int9223372036854775807_5385 = torch.constant.int 9223372036854775807 + %int1_5386 = torch.constant.int 1 + %4504 = torch.aten.slice.Tensor %4503, %int1_5383, %int0_5384, 
%int9223372036854775807_5385, %int1_5386 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4504, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_5387 = torch.constant.int 2 + %4505 = torch.aten.unsqueeze %4504, %int2_5387 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4505, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_5388 = torch.constant.int 3 + %int0_5389 = torch.constant.int 0 + %int9223372036854775807_5390 = torch.constant.int 9223372036854775807 + %int1_5391 = torch.constant.int 1 + %4506 = torch.aten.slice.Tensor %4505, %int3_5388, %int0_5389, %int9223372036854775807_5390, %int1_5391 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4506, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_5392 = torch.constant.int 4 + %int1_5393 = torch.constant.int 1 + %int1_5394 = torch.constant.int 1 + %int1_5395 = torch.constant.int 1 + %4507 = torch.prim.ListConstruct %int4_5392, %int1_5393, %int1_5394, %int1_5395 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4508 = torch.aten.repeat %4506, %4507 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %4508, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %4509 = torch.aten.mul.Tensor %4388, %4502 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4509, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_5396 = torch.constant.int 3 + %int0_5397 = torch.constant.int 0 + %int64_5398 = torch.constant.int 64 %int1_5399 = torch.constant.int 1 - %4512 = torch.aten.size.int %4458, %int1_5399 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_5400 = torch.constant.int 0 - %4513 = torch.aten.add.int %int0_5400, %4512 : !torch.int, !torch.int -> !torch.int - %int0_5401 = torch.constant.int 0 - %int0_5402 = torch.constant.int 0 + %4510 = torch.aten.slice.Tensor %4388, %int3_5396, %int0_5397, %int64_5398, %int1_5399 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %4510, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_5400 = torch.constant.int 3 + %int64_5401 = torch.constant.int 64 + %int9223372036854775807_5402 = torch.constant.int 9223372036854775807 %int1_5403 = torch.constant.int 1 - %4514 = torch.aten.slice.Tensor %4511, %int0_5401, %int0_5402, %4513, %int1_5403 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4514, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_5404 = torch.constant.int 1 - %int0_5405 = torch.constant.int 0 - %int9223372036854775807_5406 = torch.constant.int 9223372036854775807 + %4511 = torch.aten.slice.Tensor %4388, %int3_5400, %int64_5401, %int9223372036854775807_5402, %int1_5403 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %4511, [%294], 
affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %4512 = torch.aten.neg %4511 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %4512, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %4513 = torch.prim.ListConstruct %4512, %4510 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_5404 = torch.constant.int -1 + %4514 = torch.aten.cat %4513, %int-1_5404 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4514, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %4515 = torch.aten.mul.Tensor %4514, %4508 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4515, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_5405 = torch.constant.int 1 + %4516 = torch.aten.add.Tensor %4509, %4515, %int1_5405 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4516, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_5406 = torch.constant.int 32 + %4517 = torch.aten.mul.Scalar %arg2, %int32_5406 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4517, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int15 = torch.constant.int 15 %int1_5407 = torch.constant.int 1 - %4515 = torch.aten.slice.Tensor %4514, %int1_5404, %int0_5405, %int9223372036854775807_5406, %int1_5407 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4515, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_5408 = torch.constant.int 1 + %4518 = torch.aten.add.Scalar %4517, %int15, %int1_5407 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4518, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_5408 = torch.constant.int 2 + %4519 = torch.aten.mul.Scalar %4518, %int2_5408 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4519, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> %int0_5409 = torch.constant.int 0 - %int9223372036854775807_5410 = torch.constant.int 9223372036854775807 - %int1_5411 = torch.constant.int 1 - %4516 = torch.aten.slice.Tensor %4515, %int1_5408, %int0_5409, %int9223372036854775807_5410, %int1_5411 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4516, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_5412 = torch.constant.int 0 - %4517 = torch.aten.unsqueeze %4516, %int0_5412 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4517, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_5413 = torch.constant.int 1 - %int0_5414 = torch.constant.int 0 - %int9223372036854775807_5415 = torch.constant.int 9223372036854775807 - %int1_5416 = torch.constant.int 1 - %4518 = torch.aten.slice.Tensor %4517, %int1_5413, %int0_5414, %int9223372036854775807_5415, %int1_5416 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, 
!torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4518, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_5417 = torch.constant.int 2 - %int0_5418 = torch.constant.int 0 - %int9223372036854775807_5419 = torch.constant.int 9223372036854775807 - %int1_5420 = torch.constant.int 1 - %4519 = torch.aten.slice.Tensor %4518, %int2_5417, %int0_5418, %int9223372036854775807_5419, %int1_5420 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4519, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_5421 = torch.constant.int 4 - %int1_5422 = torch.constant.int 1 - %int1_5423 = torch.constant.int 1 - %4520 = torch.prim.ListConstruct %int4_5421, %int1_5422, %int1_5423 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4521 = torch.aten.repeat %4519, %4520 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %4521, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_5424 = torch.constant.int 6 - %4522 = torch.prims.convert_element_type %4469, %int6_5424 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %4522, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %4523 = torch_c.to_builtin_tensor %4522 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %4524 = torch_c.to_builtin_tensor %4521 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %4525 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%4523, %4524) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %4526 = torch_c.from_builtin_tensor %4525 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %4526, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_5425 = torch.constant.int 5 - %4527 = torch.prims.convert_element_type %4526, %int5_5425 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4527, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_5426 = torch.constant.int 64 - %4528 = torch.aten.mul.Scalar %arg2, %int64_5426 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4528, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int42 = torch.constant.int 42 - %int1_5427 = torch.constant.int 1 - %4529 = torch.aten.add.Scalar %4528, %int42, %int1_5427 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4529, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_5428 = torch.constant.int 4 - %int32_5429 = torch.constant.int 32 - %int8_5430 = torch.constant.int 8 - %int128_5431 = torch.constant.int 128 - %4530 = torch.prim.ListConstruct %int4_5428, %398, %int32_5429, %int8_5430, %int128_5431 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4531 = torch.aten.view %4527, %4530 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4531, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_5432 = torch.constant.int 4 - %4532 = torch.aten.mul.int %int4_5432, %398 : !torch.int, 
!torch.int -> !torch.int + %int1_5410 = torch.constant.int 1 + %4520 = torch.aten.add.Scalar %4519, %int0_5409, %int1_5410 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4520, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %4521 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %4522 = torch.aten.view %4520, %4521 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %4522, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_5411 = torch.constant.int 4 + %int32_5412 = torch.constant.int 32 + %int8_5413 = torch.constant.int 8 + %int128_5414 = torch.constant.int 128 + %4523 = torch.prim.ListConstruct %int4_5411, %296, %int32_5412, %int8_5413, %int128_5414 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4524 = torch.aten.view %4516, %4523 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4524, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_5415 = torch.constant.int 32 + %int8_5416 = torch.constant.int 8 + %int128_5417 = torch.constant.int 128 + %4525 = torch.prim.ListConstruct %504, %int32_5415, %int8_5416, %int128_5417 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4526 = torch.aten.view %4524, %4525 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %4526, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_5418 = torch.constant.int 1 + %int2_5419 = torch.constant.int 2 + %4527 = torch.aten.transpose.int %4526, %int1_5418, %int2_5419 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4527, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_5420 = torch.constant.int 5 + %4528 = torch.prims.convert_element_type %4527, %int5_5420 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4528, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_5421 = torch.constant.int 32 + %int2_5422 = torch.constant.int 2 + %int8_5423 = torch.constant.int 8 + %int32_5424 = torch.constant.int 32 + %int128_5425 = torch.constant.int 128 + %4529 = torch.prim.ListConstruct %297, %int32_5421, %int2_5422, %int8_5423, %int32_5424, %int128_5425 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4530 = torch.aten.view %4292, %4529 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4530, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_5426 = torch.constant.int 8 + %int32_5427 = torch.constant.int 32 + %int128_5428 = torch.constant.int 128 + %4531 = torch.prim.ListConstruct %497, %int8_5426, %int32_5427, %int128_5428 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4532 = torch.aten.view %4530, %4531 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4532, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %4533 = torch.prim.ListConstruct %4522 : (!torch.vtensor<[?],si64>) -> !torch.list> + 
%false_5429 = torch.constant.bool false + %4534 = torch.aten.index_put %4532, %4533, %4528, %false_5429 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4534, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_5430 = torch.constant.int 32 + %int2_5431 = torch.constant.int 2 + %int8_5432 = torch.constant.int 8 %int32_5433 = torch.constant.int 32 - %int8_5434 = torch.constant.int 8 - %int128_5435 = torch.constant.int 128 - %4533 = torch.prim.ListConstruct %4532, %int32_5433, %int8_5434, %int128_5435 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4534 = torch.aten.view %4531, %4533 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4534, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_5436 = torch.constant.int 4 - %4535 = torch.aten.mul.int %int4_5436, %398 : !torch.int, !torch.int -> !torch.int - %4536 = torch.prim.ListConstruct %4535 : (!torch.int) -> !torch.list - %4537 = torch.aten.view %4529, %4536 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %4537, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_5437 = torch.constant.int 32 - %int2_5438 = torch.constant.int 2 + %int128_5434 = torch.constant.int 128 + %4535 = torch.prim.ListConstruct %297, %int32_5430, %int2_5431, %int8_5432, %int32_5433, %int128_5434 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4536 = torch.aten.view %4534, %4535 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4536, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_5435 = torch.constant.int 2097152 + %4537 = torch.prim.ListConstruct %297, %int2097152_5435 : (!torch.int, !torch.int) -> !torch.list + %4538 = torch.aten.view %4536, %4537 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %4538, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_5436 = torch.constant.int 32 + %int2_5437 = torch.constant.int 2 + %int8_5438 = torch.constant.int 8 %int32_5439 = torch.constant.int 32 - %int8_5440 = torch.constant.int 8 - %int128_5441 = torch.constant.int 128 - %4538 = torch.prim.ListConstruct %389, %int32_5437, %int2_5438, %int32_5439, %int8_5440, %int128_5441 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4539 = torch.aten.view %4371, %4538 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4539, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> + %int128_5440 = torch.constant.int 128 + %4539 = torch.prim.ListConstruct %297, %int32_5436, %int2_5437, %int8_5438, %int32_5439, %int128_5440 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4540 = torch.aten.view %4538, %4539 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4540, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_5441 = torch.constant.int 8 %int32_5442 = 
torch.constant.int 32 - %4540 = torch.aten.mul.int %389, %int32_5442 : !torch.int, !torch.int -> !torch.int - %int2_5443 = torch.constant.int 2 - %4541 = torch.aten.mul.int %4540, %int2_5443 : !torch.int, !torch.int -> !torch.int + %int128_5443 = torch.constant.int 128 + %4541 = torch.prim.ListConstruct %497, %int8_5441, %int32_5442, %int128_5443 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4542 = torch.aten.view %4540, %4541 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4542, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> %int32_5444 = torch.constant.int 32 - %int8_5445 = torch.constant.int 8 - %int128_5446 = torch.constant.int 128 - %4542 = torch.prim.ListConstruct %4541, %int32_5444, %int8_5445, %int128_5446 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4543 = torch.aten.view %4539, %4542 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4543, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %4544 = torch.prim.ListConstruct %4537 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_5447 = torch.constant.bool false - %4545 = torch.aten.index_put %4543, %4544, %4534, %false_5447 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4545, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_5448 = torch.constant.int 32 - %int2_5449 = torch.constant.int 2 - %int32_5450 = torch.constant.int 32 - %int8_5451 = torch.constant.int 8 - %int128_5452 = torch.constant.int 128 - %4546 = torch.prim.ListConstruct %389, %int32_5448, %int2_5449, %int32_5450, %int8_5451, %int128_5452 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4547 = torch.aten.view %4545, %4546 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4547, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5453 = torch.constant.int 2097152 - %4548 = torch.prim.ListConstruct %389, %int2097152_5453 : (!torch.int, !torch.int) -> !torch.list - %4549 = torch.aten.view %4547, %4548 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4549, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %4543 = torch.aten.mul.Scalar %arg2, %int32_5444 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4543, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int15_5445 = torch.constant.int 15 + %int1_5446 = torch.constant.int 1 + %4544 = torch.aten.add.Scalar %4543, %int15_5445, %int1_5446 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4544, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_5447 = torch.constant.int 2 + %4545 = torch.aten.mul.Scalar %4544, %int2_5447 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4545, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_5448 = torch.constant.int 1 + %int1_5449 = torch.constant.int 1 + %4546 = torch.aten.add.Scalar 
%4545, %int1_5448, %int1_5449 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4546, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %4547 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %4548 = torch.aten.view %4546, %4547 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %4548, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_5450 = torch.constant.int 4 + %int32_5451 = torch.constant.int 32 + %int8_5452 = torch.constant.int 8 + %int128_5453 = torch.constant.int 128 + %4549 = torch.prim.ListConstruct %int4_5450, %296, %int32_5451, %int8_5452, %int128_5453 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4550 = torch.aten.view %4390, %4549 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4550, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int32_5454 = torch.constant.int 32 - %int2_5455 = torch.constant.int 2 - %int32_5456 = torch.constant.int 32 - %int8_5457 = torch.constant.int 8 - %int128_5458 = torch.constant.int 128 - %4550 = torch.prim.ListConstruct %389, %int32_5454, %int2_5455, %int32_5456, %int8_5457, %int128_5458 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4551 = torch.aten.view %4549, %4550 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4551, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5459 = torch.constant.int 32 - %int8_5460 = torch.constant.int 8 - %int128_5461 = torch.constant.int 128 - %4552 = torch.prim.ListConstruct %4541, %int32_5459, %int8_5460, %int128_5461 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4553 = torch.aten.view %4551, %4552 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4553, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_5462 = torch.constant.int 4 - %int32_5463 = torch.constant.int 32 - %int8_5464 = torch.constant.int 8 + %int8_5455 = torch.constant.int 8 + %int128_5456 = torch.constant.int 128 + %4551 = torch.prim.ListConstruct %504, %int32_5454, %int8_5455, %int128_5456 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4552 = torch.aten.view %4550, %4551 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %4552, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_5457 = torch.constant.int 1 + %int2_5458 = torch.constant.int 2 + %4553 = torch.aten.transpose.int %4552, %int1_5457, %int2_5458 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4553, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_5459 = torch.constant.int 5 + %4554 = torch.prims.convert_element_type %4553, %int5_5459 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4554, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %4555 = torch.prim.ListConstruct %4548 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_5460 = 
torch.constant.bool false + %4556 = torch.aten.index_put %4542, %4555, %4554, %false_5460 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4556, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_5461 = torch.constant.int 32 + %int2_5462 = torch.constant.int 2 + %int8_5463 = torch.constant.int 8 + %int32_5464 = torch.constant.int 32 %int128_5465 = torch.constant.int 128 - %4554 = torch.prim.ListConstruct %int4_5462, %398, %int32_5463, %int8_5464, %int128_5465 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4555 = torch.aten.view %4471, %4554 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4555, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_5466 = torch.constant.int 4 - %4556 = torch.aten.mul.int %int4_5466, %398 : !torch.int, !torch.int -> !torch.int - %int32_5467 = torch.constant.int 32 - %int8_5468 = torch.constant.int 8 - %int128_5469 = torch.constant.int 128 - %4557 = torch.prim.ListConstruct %4556, %int32_5467, %int8_5468, %int128_5469 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4558 = torch.aten.view %4555, %4557 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4558, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_5470 = torch.constant.int 1 - %int1_5471 = torch.constant.int 1 - %4559 = torch.aten.add.Scalar %4529, %int1_5470, %int1_5471 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4559, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_5472 = torch.constant.int 4 - %4560 = torch.aten.mul.int %int4_5472, %398 : !torch.int, !torch.int -> !torch.int - %4561 = torch.prim.ListConstruct %4560 : (!torch.int) -> !torch.list - %4562 = torch.aten.view %4559, %4561 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %4562, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %4563 = torch.prim.ListConstruct %4562 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_5473 = torch.constant.bool false - %4564 = torch.aten.index_put %4553, %4563, %4558, %false_5473 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4564, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_5474 = torch.constant.int 32 - %int2_5475 = torch.constant.int 2 - %int32_5476 = torch.constant.int 32 - %int8_5477 = torch.constant.int 8 - %int128_5478 = torch.constant.int 128 - %4565 = torch.prim.ListConstruct %389, %int32_5474, %int2_5475, %int32_5476, %int8_5477, %int128_5478 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4566 = torch.aten.view %4564, %4565 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4566, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5479 = torch.constant.int 2097152 - %4567 = torch.prim.ListConstruct %389, %int2097152_5479 : (!torch.int, !torch.int) -> !torch.list - %4568 = torch.aten.view 
%4566, %4567 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4568, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_5480 = torch.constant.int -2 - %4569 = torch.aten.unsqueeze %4527, %int-2_5480 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4569, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_5481 = torch.constant.int 4 - %int8_5482 = torch.constant.int 8 - %int4_5483 = torch.constant.int 4 - %int128_5484 = torch.constant.int 128 - %4570 = torch.prim.ListConstruct %int4_5481, %4512, %int8_5482, %int4_5483, %int128_5484 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_5485 = torch.constant.bool false - %4571 = torch.aten.expand %4569, %4570, %false_5485 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4571, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_5486 = torch.constant.int 0 - %4572 = torch.aten.clone %4571, %int0_5486 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4572, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_5487 = torch.constant.int 4 - %int32_5488 = torch.constant.int 32 - %int128_5489 = torch.constant.int 128 - %4573 = torch.prim.ListConstruct %int4_5487, %4512, %int32_5488, %int128_5489 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4574 = torch.aten._unsafe_view %4572, %4573 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4574, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_5490 = torch.constant.int -2 - %4575 = torch.aten.unsqueeze %4471, %int-2_5490 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4575, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %4557 = torch.prim.ListConstruct %297, %int32_5461, %int2_5462, %int8_5463, %int32_5464, %int128_5465 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4558 = torch.aten.view %4556, %4557 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4558, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_5466 = torch.constant.int 2097152 + %4559 = torch.prim.ListConstruct %297, %int2097152_5466 : (!torch.int, !torch.int) -> !torch.list + %4560 = torch.aten.view %4558, %4559 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %4560, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_5467 = torch.constant.int -2 + %4561 = torch.aten.unsqueeze %4516, %int-2_5467 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %4561, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_5468 = torch.constant.int 4 + %int8_5469 = torch.constant.int 8 + %int4_5470 = torch.constant.int 4 + %int128_5471 = torch.constant.int 128 + %4562 = 
torch.prim.ListConstruct %int4_5468, %298, %int8_5469, %int4_5470, %int128_5471 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_5472 = torch.constant.bool false + %4563 = torch.aten.expand %4561, %4562, %false_5472 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4563, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_5473 = torch.constant.int 0 + %4564 = torch.aten.clone %4563, %int0_5473 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4564, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_5474 = torch.constant.int 4 + %int32_5475 = torch.constant.int 32 + %int128_5476 = torch.constant.int 128 + %4565 = torch.prim.ListConstruct %int4_5474, %298, %int32_5475, %int128_5476 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4566 = torch.aten._unsafe_view %4564, %4565 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4566, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_5477 = torch.constant.int -2 + %4567 = torch.aten.unsqueeze %4390, %int-2_5477 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %4567, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_5478 = torch.constant.int 4 + %int8_5479 = torch.constant.int 8 + %int4_5480 = torch.constant.int 4 + %int128_5481 = torch.constant.int 128 + %4568 = torch.prim.ListConstruct %int4_5478, %298, %int8_5479, %int4_5480, %int128_5481 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_5482 = torch.constant.bool false + %4569 = torch.aten.expand %4567, %4568, %false_5482 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4569, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_5483 = torch.constant.int 0 + %4570 = torch.aten.clone %4569, %int0_5483 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4570, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_5484 = torch.constant.int 4 + %int32_5485 = torch.constant.int 32 + %int128_5486 = torch.constant.int 128 + %4571 = torch.prim.ListConstruct %int4_5484, %298, %int32_5485, %int128_5486 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4572 = torch.aten._unsafe_view %4570, %4571 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4572, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_5487 = torch.constant.int 1 + %int2_5488 = torch.constant.int 2 + %4573 = torch.aten.transpose.int %4453, %int1_5487, %int2_5488 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4573, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_5489 = torch.constant.int 1 + %int2_5490 = torch.constant.int 2 + %4574 = torch.aten.transpose.int %4566, %int1_5489, %int2_5490 : !torch.vtensor<[4,?,32,128],f16>, 
!torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4574, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_5491 = torch.constant.int 1 - %4576 = torch.aten.size.int %4465, %int1_5491 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_5492 = torch.constant.int 4 - %int8_5493 = torch.constant.int 8 - %int4_5494 = torch.constant.int 4 - %int128_5495 = torch.constant.int 128 - %4577 = torch.prim.ListConstruct %int4_5492, %4576, %int8_5493, %int4_5494, %int128_5495 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_5496 = torch.constant.bool false - %4578 = torch.aten.expand %4575, %4577, %false_5496 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4578, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_5497 = torch.constant.int 0 - %4579 = torch.aten.clone %4578, %int0_5497 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4579, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int2_5492 = torch.constant.int 2 + %4575 = torch.aten.transpose.int %4572, %int1_5491, %int2_5492 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4575, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_5493 = torch.constant.float 0.000000e+00 + %false_5494 = torch.constant.bool false + %none_5495 = torch.constant.none + %4576:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4573, %4574, %4575, %float0.000000e00_5493, %false_5494, %327, %none_5495) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %4576#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_5496 = torch.constant.int 1 + %int2_5497 = torch.constant.int 2 + %4577 = torch.aten.transpose.int %4576#0, %int1_5496, %int2_5497 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4577, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int4_5498 = torch.constant.int 4 - %int32_5499 = torch.constant.int 32 - %int128_5500 = torch.constant.int 128 - %4580 = torch.prim.ListConstruct %int4_5498, %4576, %int32_5499, %int128_5500 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4581 = torch.aten._unsafe_view %4579, %4580 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4581, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_5501 = torch.constant.int 1 - %int2_5502 = torch.constant.int 2 - %4582 = torch.aten.transpose.int %4499, %int1_5501, %int2_5502 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4582, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_5503 = torch.constant.int 1 - %int2_5504 = torch.constant.int 2 - %4583 = 
torch.aten.transpose.int %4574, %int1_5503, %int2_5504 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4583, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_5505 = torch.constant.int 1 - %int2_5506 = torch.constant.int 2 - %4584 = torch.aten.transpose.int %4581, %int1_5505, %int2_5506 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4584, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_5507 = torch.constant.float 0.000000e+00 - %true_5508 = torch.constant.bool true - %none_5509 = torch.constant.none - %none_5510 = torch.constant.none - %4585:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4582, %4583, %4584, %float0.000000e00_5507, %true_5508, %none_5509, %none_5510) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %4585#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_5511 = torch.constant.int 1 - %int2_5512 = torch.constant.int 2 - %4586 = torch.aten.transpose.int %4585#0, %int1_5511, %int2_5512 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4586, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_5513 = torch.constant.int 4 - %int4096_5514 = torch.constant.int 4096 - %4587 = torch.prim.ListConstruct %int4_5513, %4484, %int4096_5514 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4588 = torch.aten.view %4586, %4587 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4588, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_5515 = torch.constant.int -2 - %int-1_5516 = torch.constant.int -1 - %4589 = torch.aten.transpose.int %194, %int-2_5515, %int-1_5516 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_5517 = torch.constant.int 4 - %4590 = torch.aten.mul.int %int4_5517, %4484 : !torch.int, !torch.int -> !torch.int - %int4096_5518 = torch.constant.int 4096 - %4591 = torch.prim.ListConstruct %4590, %int4096_5518 : (!torch.int, !torch.int) -> !torch.list - %4592 = torch.aten.view %4588, %4591 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4592, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4593 = torch.aten.mm %4592, %4589 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4593, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_5519 = torch.constant.int 4 - %int4096_5520 = torch.constant.int 4096 - %4594 = torch.prim.ListConstruct %int4_5519, %4484, %int4096_5520 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4595 = torch.aten.view %4593, %4594 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4595, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_5521 = torch.constant.int 1 - 
%4596 = torch.aten.add.Tensor %4434, %4595, %int1_5521 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4596, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_5522 = torch.constant.int 6 - %4597 = torch.prims.convert_element_type %4596, %int6_5522 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4597, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_5523 = torch.constant.int 2 - %4598 = torch.aten.pow.Tensor_Scalar %4597, %int2_5523 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4598, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_5524 = torch.constant.int -1 - %4599 = torch.prim.ListConstruct %int-1_5524 : (!torch.int) -> !torch.list - %true_5525 = torch.constant.bool true - %none_5526 = torch.constant.none - %4600 = torch.aten.mean.dim %4598, %4599, %true_5525, %none_5526 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4600, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_5527 = torch.constant.float 9.9999997473787516E-6 - %int1_5528 = torch.constant.int 1 - %4601 = torch.aten.add.Scalar %4600, %float9.999990e-06_5527, %int1_5528 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4601, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4602 = torch.aten.rsqrt %4601 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4602, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4603 = torch.aten.mul.Tensor %4597, %4602 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4603, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_5529 = torch.constant.int 5 - %4604 = torch.prims.convert_element_type %4603, %int5_5529 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4604, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %4605 = torch.aten.mul.Tensor %195, %4604 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4605, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int4096_5499 = torch.constant.int 4096 + %4578 = torch.prim.ListConstruct %int4_5498, %298, %int4096_5499 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4579 = torch.aten.view %4577, %4578 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4579, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_5500 = torch.constant.int -2 + %int-1_5501 = torch.constant.int -1 + %4580 = torch.aten.transpose.int %141, %int-2_5500, %int-1_5501 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5502 = torch.constant.int 5 + %4581 = torch.prims.convert_element_type %4580, %int5_5502 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + 
%int4096_5503 = torch.constant.int 4096 + %4582 = torch.prim.ListConstruct %342, %int4096_5503 : (!torch.int, !torch.int) -> !torch.list + %4583 = torch.aten.view %4579, %4582 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4583, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4584 = torch.aten.mm %4583, %4581 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4584, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_5504 = torch.constant.int 4 + %int4096_5505 = torch.constant.int 4096 + %4585 = torch.prim.ListConstruct %int4_5504, %298, %int4096_5505 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4586 = torch.aten.view %4584, %4585 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4586, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_5506 = torch.constant.int 1 + %4587 = torch.aten.add.Tensor %4353, %4586, %int1_5506 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4587, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_5507 = torch.constant.int 6 + %4588 = torch.prims.convert_element_type %4587, %int6_5507 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4588, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_5508 = torch.constant.int 2 + %4589 = torch.aten.pow.Tensor_Scalar %4588, %int2_5508 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4589, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_5509 = torch.constant.int -1 + %4590 = torch.prim.ListConstruct %int-1_5509 : (!torch.int) -> !torch.list + %true_5510 = torch.constant.bool true + %none_5511 = torch.constant.none + %4591 = torch.aten.mean.dim %4589, %4590, %true_5510, %none_5511 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4591, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_5512 = torch.constant.float 9.9999997473787516E-6 + %int1_5513 = torch.constant.int 1 + %4592 = torch.aten.add.Scalar %4591, %float9.999990e-06_5512, %int1_5513 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4592, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4593 = torch.aten.rsqrt %4592 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4593, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4594 = torch.aten.mul.Tensor %4588, %4593 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4594, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_5514 = torch.constant.int 5 + %4595 = torch.prims.convert_element_type %4594, %int5_5514 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4595, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : 
!torch.vtensor<[4,?,4096],f16> + %4596 = torch.aten.mul.Tensor %142, %4595 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4596, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_5515 = torch.constant.int 5 + %4597 = torch.prims.convert_element_type %4596, %int5_5515 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4597, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_5516 = torch.constant.int -2 + %int-1_5517 = torch.constant.int -1 + %4598 = torch.aten.transpose.int %143, %int-2_5516, %int-1_5517 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5518 = torch.constant.int 5 + %4599 = torch.prims.convert_element_type %4598, %int5_5518 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_5519 = torch.constant.int 4096 + %4600 = torch.prim.ListConstruct %342, %int4096_5519 : (!torch.int, !torch.int) -> !torch.list + %4601 = torch.aten.view %4597, %4600 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4601, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4602 = torch.aten.mm %4601, %4599 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %4602, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_5520 = torch.constant.int 4 + %int14336_5521 = torch.constant.int 14336 + %4603 = torch.prim.ListConstruct %int4_5520, %298, %int14336_5521 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4604 = torch.aten.view %4602, %4603 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4604, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %4605 = torch.aten.silu %4604 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4605, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_5522 = torch.constant.int -2 + %int-1_5523 = torch.constant.int -1 + %4606 = torch.aten.transpose.int %144, %int-2_5522, %int-1_5523 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5524 = torch.constant.int 5 + %4607 = torch.prims.convert_element_type %4606, %int5_5524 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_5525 = torch.constant.int 4096 + %4608 = torch.prim.ListConstruct %342, %int4096_5525 : (!torch.int, !torch.int) -> !torch.list + %4609 = torch.aten.view %4597, %4608 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4609, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4610 = torch.aten.mm %4609, %4607 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %4610, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_5526 = torch.constant.int 4 + %int14336_5527 = torch.constant.int 14336 + %4611 = torch.prim.ListConstruct %int4_5526, %298, %int14336_5527 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4612 = torch.aten.view %4610, %4611 : 
!torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4612, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %4613 = torch.aten.mul.Tensor %4605, %4612 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4613, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_5528 = torch.constant.int -2 + %int-1_5529 = torch.constant.int -1 + %4614 = torch.aten.transpose.int %145, %int-2_5528, %int-1_5529 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> %int5_5530 = torch.constant.int 5 - %4606 = torch.prims.convert_element_type %4605, %int5_5530 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4606, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_5531 = torch.constant.int -2 - %int-1_5532 = torch.constant.int -1 - %4607 = torch.aten.transpose.int %196, %int-2_5531, %int-1_5532 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_5533 = torch.constant.int 4 - %4608 = torch.aten.mul.int %int4_5533, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5534 = torch.constant.int 4096 - %4609 = torch.prim.ListConstruct %4608, %int4096_5534 : (!torch.int, !torch.int) -> !torch.list - %4610 = torch.aten.view %4606, %4609 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4610, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4611 = torch.aten.mm %4610, %4607 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %4611, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_5535 = torch.constant.int 4 - %int14336_5536 = torch.constant.int 14336 - %4612 = torch.prim.ListConstruct %int4_5535, %306, %int14336_5536 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4613 = torch.aten.view %4611, %4612 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %4613, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %4614 = torch.aten.silu %4613 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %4614, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_5537 = torch.constant.int -2 - %int-1_5538 = torch.constant.int -1 - %4615 = torch.aten.transpose.int %197, %int-2_5537, %int-1_5538 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_5539 = torch.constant.int 4 - %4616 = torch.aten.mul.int %int4_5539, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5540 = torch.constant.int 4096 - %4617 = torch.prim.ListConstruct %4616, %int4096_5540 : (!torch.int, !torch.int) -> !torch.list - %4618 = torch.aten.view %4606, %4617 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4618, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4619 = torch.aten.mm %4618, %4615 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %4619, [%292], affine_map<()[s0] 
-> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_5541 = torch.constant.int 4 - %int14336_5542 = torch.constant.int 14336 - %4620 = torch.prim.ListConstruct %int4_5541, %306, %int14336_5542 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4621 = torch.aten.view %4619, %4620 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %4621, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %4622 = torch.aten.mul.Tensor %4614, %4621 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %4622, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_5543 = torch.constant.int -2 - %int-1_5544 = torch.constant.int -1 - %4623 = torch.aten.transpose.int %198, %int-2_5543, %int-1_5544 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_5545 = torch.constant.int 1 - %4624 = torch.aten.size.int %4613, %int1_5545 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_5546 = torch.constant.int 4 - %4625 = torch.aten.mul.int %int4_5546, %4624 : !torch.int, !torch.int -> !torch.int - %int14336_5547 = torch.constant.int 14336 - %4626 = torch.prim.ListConstruct %4625, %int14336_5547 : (!torch.int, !torch.int) -> !torch.list - %4627 = torch.aten.view %4622, %4626 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %4627, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %4628 = torch.aten.mm %4627, %4623 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4628, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4615 = torch.prims.convert_element_type %4614, %int5_5530 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_5531 = torch.constant.int 14336 + %4616 = torch.prim.ListConstruct %342, %int14336_5531 : (!torch.int, !torch.int) -> !torch.list + %4617 = torch.aten.view %4613, %4616 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %4617, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %4618 = torch.aten.mm %4617, %4615 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4618, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_5532 = torch.constant.int 4 + %int4096_5533 = torch.constant.int 4096 + %4619 = torch.prim.ListConstruct %int4_5532, %298, %int4096_5533 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4620 = torch.aten.view %4618, %4619 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4620, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_5534 = torch.constant.int 1 + %4621 = torch.aten.add.Tensor %4587, %4620, %int1_5534 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4621, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_5535 = torch.constant.int 6 + %4622 = torch.prims.convert_element_type %4621, %int6_5535 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> 
!torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4622, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_5536 = torch.constant.int 2 + %4623 = torch.aten.pow.Tensor_Scalar %4622, %int2_5536 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4623, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_5537 = torch.constant.int -1 + %4624 = torch.prim.ListConstruct %int-1_5537 : (!torch.int) -> !torch.list + %true_5538 = torch.constant.bool true + %none_5539 = torch.constant.none + %4625 = torch.aten.mean.dim %4623, %4624, %true_5538, %none_5539 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4625, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_5540 = torch.constant.float 9.9999997473787516E-6 + %int1_5541 = torch.constant.int 1 + %4626 = torch.aten.add.Scalar %4625, %float9.999990e-06_5540, %int1_5541 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4626, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4627 = torch.aten.rsqrt %4626 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4627, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4628 = torch.aten.mul.Tensor %4622, %4627 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4628, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_5542 = torch.constant.int 5 + %4629 = torch.prims.convert_element_type %4628, %int5_5542 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4629, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %4630 = torch.aten.mul.Tensor %146, %4629 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4630, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_5543 = torch.constant.int 5 + %4631 = torch.prims.convert_element_type %4630, %int5_5543 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4631, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_5544 = torch.constant.int -2 + %int-1_5545 = torch.constant.int -1 + %4632 = torch.aten.transpose.int %147, %int-2_5544, %int-1_5545 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5546 = torch.constant.int 5 + %4633 = torch.prims.convert_element_type %4632, %int5_5546 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_5547 = torch.constant.int 4096 + %4634 = torch.prim.ListConstruct %342, %int4096_5547 : (!torch.int, !torch.int) -> !torch.list + %4635 = torch.aten.view %4631, %4634 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4635, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4636 = torch.aten.mm %4635, %4633 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4636, [%294], 
affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> %int4_5548 = torch.constant.int 4 %int4096_5549 = torch.constant.int 4096 - %4629 = torch.prim.ListConstruct %int4_5548, %4624, %int4096_5549 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4630 = torch.aten.view %4628, %4629 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4630, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_5550 = torch.constant.int 1 - %4631 = torch.aten.add.Tensor %4596, %4630, %int1_5550 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4631, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_5551 = torch.constant.int 6 - %4632 = torch.prims.convert_element_type %4631, %int6_5551 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4632, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_5552 = torch.constant.int 2 - %4633 = torch.aten.pow.Tensor_Scalar %4632, %int2_5552 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4633, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_5553 = torch.constant.int -1 - %4634 = torch.prim.ListConstruct %int-1_5553 : (!torch.int) -> !torch.list - %true_5554 = torch.constant.bool true - %none_5555 = torch.constant.none - %4635 = torch.aten.mean.dim %4633, %4634, %true_5554, %none_5555 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4635, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_5556 = torch.constant.float 9.9999997473787516E-6 - %int1_5557 = torch.constant.int 1 - %4636 = torch.aten.add.Scalar %4635, %float9.999990e-06_5556, %int1_5557 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4636, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4637 = torch.aten.rsqrt %4636 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4637, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4638 = torch.aten.mul.Tensor %4632, %4637 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4638, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %4637 = torch.prim.ListConstruct %int4_5548, %298, %int4096_5549 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4638 = torch.aten.view %4636, %4637 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4638, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_5550 = torch.constant.int -2 + %int-1_5551 = torch.constant.int -1 + %4639 = torch.aten.transpose.int %148, %int-2_5550, %int-1_5551 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5552 = torch.constant.int 5 + %4640 = torch.prims.convert_element_type %4639, %int5_5552 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_5553 = torch.constant.int 4096 + %4641 = 
torch.prim.ListConstruct %342, %int4096_5553 : (!torch.int, !torch.int) -> !torch.list + %4642 = torch.aten.view %4631, %4641 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4642, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4643 = torch.aten.mm %4642, %4640 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %4643, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_5554 = torch.constant.int 4 + %int1024_5555 = torch.constant.int 1024 + %4644 = torch.prim.ListConstruct %int4_5554, %298, %int1024_5555 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4645 = torch.aten.view %4643, %4644 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %4645, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_5556 = torch.constant.int -2 + %int-1_5557 = torch.constant.int -1 + %4646 = torch.aten.transpose.int %149, %int-2_5556, %int-1_5557 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> %int5_5558 = torch.constant.int 5 - %4639 = torch.prims.convert_element_type %4638, %int5_5558 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4639, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %4640 = torch.aten.mul.Tensor %199, %4639 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4640, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_5559 = torch.constant.int 5 - %4641 = torch.prims.convert_element_type %4640, %int5_5559 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4641, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_5560 = torch.constant.int -2 - %int-1_5561 = torch.constant.int -1 - %4642 = torch.aten.transpose.int %200, %int-2_5560, %int-1_5561 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %4647 = torch.prims.convert_element_type %4646, %int5_5558 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_5559 = torch.constant.int 4096 + %4648 = torch.prim.ListConstruct %342, %int4096_5559 : (!torch.int, !torch.int) -> !torch.list + %4649 = torch.aten.view %4631, %4648 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4649, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4650 = torch.aten.mm %4649, %4647 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %4650, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_5560 = torch.constant.int 4 + %int1024_5561 = torch.constant.int 1024 + %4651 = torch.prim.ListConstruct %int4_5560, %298, %int1024_5561 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4652 = torch.aten.view %4650, %4651 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %4652, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> %int4_5562 = torch.constant.int 4 - %4643 = 
torch.aten.mul.int %int4_5562, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5563 = torch.constant.int 4096 - %4644 = torch.prim.ListConstruct %4643, %int4096_5563 : (!torch.int, !torch.int) -> !torch.list - %4645 = torch.aten.view %4641, %4644 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4645, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4646 = torch.aten.mm %4645, %4642 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4646, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_5564 = torch.constant.int 4 - %int4096_5565 = torch.constant.int 4096 - %4647 = torch.prim.ListConstruct %int4_5564, %306, %int4096_5565 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4648 = torch.aten.view %4646, %4647 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4648, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_5566 = torch.constant.int -2 - %int-1_5567 = torch.constant.int -1 - %4649 = torch.aten.transpose.int %201, %int-2_5566, %int-1_5567 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int32_5563 = torch.constant.int 32 + %int128_5564 = torch.constant.int 128 + %4653 = torch.prim.ListConstruct %int4_5562, %298, %int32_5563, %int128_5564 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4654 = torch.aten.view %4638, %4653 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4654, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_5565 = torch.constant.int 4 + %int8_5566 = torch.constant.int 8 + %int128_5567 = torch.constant.int 128 + %4655 = torch.prim.ListConstruct %int4_5565, %298, %int8_5566, %int128_5567 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4656 = torch.aten.view %4645, %4655 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4656, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> %int4_5568 = torch.constant.int 4 - %4650 = torch.aten.mul.int %int4_5568, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5569 = torch.constant.int 4096 - %4651 = torch.prim.ListConstruct %4650, %int4096_5569 : (!torch.int, !torch.int) -> !torch.list - %4652 = torch.aten.view %4641, %4651 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4652, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4653 = torch.aten.mm %4652, %4649 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %4653, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_5570 = torch.constant.int 4 - %int1024_5571 = torch.constant.int 1024 - %4654 = torch.prim.ListConstruct %int4_5570, %306, %int1024_5571 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4655 = torch.aten.view %4653, %4654 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %4655, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_5572 = torch.constant.int -2 - %int-1_5573 = 
torch.constant.int -1 - %4656 = torch.aten.transpose.int %202, %int-2_5572, %int-1_5573 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_5574 = torch.constant.int 4 - %4657 = torch.aten.mul.int %int4_5574, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5575 = torch.constant.int 4096 - %4658 = torch.prim.ListConstruct %4657, %int4096_5575 : (!torch.int, !torch.int) -> !torch.list - %4659 = torch.aten.view %4641, %4658 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4659, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4660 = torch.aten.mm %4659, %4656 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %4660, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_5576 = torch.constant.int 4 - %int1024_5577 = torch.constant.int 1024 - %4661 = torch.prim.ListConstruct %int4_5576, %306, %int1024_5577 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4662 = torch.aten.view %4660, %4661 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %4662, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_5578 = torch.constant.int 4 - %int32_5579 = torch.constant.int 32 - %int128_5580 = torch.constant.int 128 - %4663 = torch.prim.ListConstruct %int4_5578, %306, %int32_5579, %int128_5580 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4664 = torch.aten.view %4648, %4663 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4664, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_5581 = torch.constant.int 4 - %int8_5582 = torch.constant.int 8 - %int128_5583 = torch.constant.int 128 - %4665 = torch.prim.ListConstruct %int4_5581, %306, %int8_5582, %int128_5583 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4666 = torch.aten.view %4655, %4665 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4666, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_5584 = torch.constant.int 4 - %int8_5585 = torch.constant.int 8 - %int128_5586 = torch.constant.int 128 - %4667 = torch.prim.ListConstruct %int4_5584, %306, %int8_5585, %int128_5586 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4668 = torch.aten.view %4662, %4667 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4668, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_5587 = torch.constant.int 131072 - %none_5588 = torch.constant.none - %none_5589 = torch.constant.none - %cpu_5590 = torch.constant.device "cpu" - %false_5591 = torch.constant.bool false - %4669 = torch.aten.arange %int131072_5587, %none_5588, %none_5589, %cpu_5590, %false_5591 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_5592 = torch.constant.int 0 - %int128_5593 = torch.constant.int 128 - %none_5594 = torch.constant.none - %none_5595 = torch.constant.none - %cpu_5596 = torch.constant.device "cpu" - %false_5597 = torch.constant.bool false - %4670 = torch.aten.arange.start %int0_5592, %int128_5593, %none_5594, %none_5595, 
%cpu_5596, %false_5597 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64>
- %int2_5598 = torch.constant.int 2
- %4671 = torch.aten.floor_divide.Scalar %4670, %int2_5598 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64>
- %int6_5599 = torch.constant.int 6
- %4672 = torch.prims.convert_element_type %4671, %int6_5599 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32>
- %int128_5600 = torch.constant.int 128
- %4673 = torch.aten.div.Scalar %4672, %int128_5600 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
- %float2.000000e00_5601 = torch.constant.float 2.000000e+00
- %4674 = torch.aten.mul.Scalar %4673, %float2.000000e00_5601 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
- %float5.000000e05_5602 = torch.constant.float 5.000000e+05
- %4675 = torch.aten.pow.Scalar %float5.000000e05_5602, %4674 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
- %4676 = torch.aten.reciprocal %4675 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
- %float1.000000e00_5603 = torch.constant.float 1.000000e+00
- %4677 = torch.aten.mul.Scalar %4676, %float1.000000e00_5603 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
- %int1_5604 = torch.constant.int 1
- %4678 = torch.aten.unsqueeze %4669, %int1_5604 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
- %int0_5605 = torch.constant.int 0
- %4679 = torch.aten.unsqueeze %4677, %int0_5605 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
- %4680 = torch.aten.mul.Tensor %4678, %4679 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
- %int1_5606 = torch.constant.int 1
- %4681 = torch.aten.size.int %4648, %int1_5606 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int
- %int0_5607 = torch.constant.int 0
- %4682 = torch.aten.add.int %int0_5607, %4681 : !torch.int, !torch.int -> !torch.int
+ %int8_5569 = torch.constant.int 8
+ %int128_5570 = torch.constant.int 128
+ %4657 = torch.prim.ListConstruct %int4_5568, %298, %int8_5569, %int128_5570 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %4658 = torch.aten.view %4652, %4657 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %4658, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int131072_5571 = torch.constant.int 131072
+ %none_5572 = torch.constant.none
+ %none_5573 = torch.constant.none
+ %cpu_5574 = torch.constant.device "cpu"
+ %false_5575 = torch.constant.bool false
+ %4659 = torch.aten.arange %int131072_5571, %none_5572, %none_5573, %cpu_5574, %false_5575 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
+ %int0_5576 = torch.constant.int 0
+ %int128_5577 = torch.constant.int 128
+ %int2_5578 = torch.constant.int 2
+ %int4_5579 = torch.constant.int 4
+ %none_5580 = torch.constant.none
+ %cpu_5581 = torch.constant.device "cpu"
+ %false_5582 = torch.constant.bool false
+ %4660 = torch.aten.arange.start_step %int0_5576, %int128_5577, %int2_5578, %int4_5579, %none_5580, %cpu_5581, %false_5582 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64>
+ %int6_5583 = torch.constant.int 6
+ %4661 = torch.prims.convert_element_type %4660, %int6_5583 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32>
+ %int128_5584 = torch.constant.int 128
+ %4662 = torch.aten.div.Scalar %4661, %int128_5584 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float5.000000e05_5585 = torch.constant.float 5.000000e+05
+ %4663 = torch.aten.pow.Scalar %float5.000000e05_5585, %4662 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %4664 = torch.aten.reciprocal %4663 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float1.000000e00_5586 = torch.constant.float 1.000000e+00
+ %4665 = torch.aten.mul.Scalar %4664, %float1.000000e00_5586 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %4666 = torch.aten.reciprocal %4665 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float6.283190e00_5587 = torch.constant.float 6.2831853071795862
+ %4667 = torch.aten.mul.Scalar %4666, %float6.283190e00_5587 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %float8.192000e03_5588 = torch.constant.float 8.192000e+03
+ %4668 = torch.aten.gt.Scalar %4667, %float8.192000e03_5588 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %int8_5589 = torch.constant.int 8
+ %4669 = torch.aten.div.Scalar %4665, %int8_5589 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %4670 = torch.aten.where.self %4668, %4669, %4665 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %4671 = torch.aten.reciprocal %4667 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8192_5590 = torch.constant.int 8192
+ %4672 = torch.aten.mul.Scalar %4671, %int8192_5590 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %int1_5591 = torch.constant.int 1
+ %int1_5592 = torch.constant.int 1
+ %4673 = torch.aten.sub.Scalar %4672, %int1_5591, %int1_5592 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %int3_5593 = torch.constant.int 3
+ %4674 = torch.aten.div.Scalar %4673, %int3_5593 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %int1_5594 = torch.constant.int 1
+ %int1_5595 = torch.constant.int 1
+ %4675 = torch.aten.rsub.Scalar %4674, %int1_5594, %int1_5595 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %4676 = torch.aten.mul.Tensor %4675, %4670 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8_5596 = torch.constant.int 8
+ %4677 = torch.aten.div.Scalar %4676, %int8_5596 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %4678 = torch.aten.mul.Tensor %4674, %4670 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int1_5597 = torch.constant.int 1
+ %4679 = torch.aten.add.Tensor %4677, %4678, %int1_5597 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float2.048000e03_5598 = torch.constant.float 2.048000e+03
+ %4680 = torch.aten.lt.Scalar %4667, %float2.048000e03_5598 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %4681 = torch.aten.bitwise_not %4680 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %float8.192000e03_5599 = torch.constant.float 8.192000e+03
+ %4682 = torch.aten.gt.Scalar %4667, %float8.192000e03_5599 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %4683 = torch.aten.bitwise_not %4682 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %4684 = torch.aten.mul.Tensor %4681, %4683 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %4685 = torch.aten.where.self %4684, %4679, %4670 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %4686 = torch.prim.ListConstruct %4685, %4685 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list<vtensor>
+ %int-1_5600 = torch.constant.int -1
+ %4687 = torch.aten.cat %4686, %int-1_5600 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[128],f32>
+ %int6_5601 = torch.constant.int 6
+ %4688 = torch.prims.convert_element_type %4687, %int6_5601 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
+ %int1_5602 = torch.constant.int 1
+ %4689 = torch.aten.unsqueeze %4659, %int1_5602 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
+ %int6_5603 = torch.constant.int 6
+ %4690 = torch.prims.convert_element_type %4689, %int6_5603 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32>
+ %int0_5604 = torch.constant.int 0
+ %4691 = torch.aten.unsqueeze %4688, %int0_5604 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+ %int6_5605 = torch.constant.int 6
+ %4692 = torch.prims.convert_element_type %4691, %int6_5605 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+ %4693 = torch.aten.mul.Tensor %4690, %4692 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %4694 = torch.aten.cos %4693 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %int5_5606 = torch.constant.int 5
+ %4695 = torch.prims.convert_element_type %4694, %int5_5606 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
+ %4696 = torch.aten.sin %4693 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %int5_5607 = torch.constant.int 5
+ %4697 = torch.prims.convert_element_type %4696, %int5_5607 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
 %int0_5608 = torch.constant.int 0
 %int0_5609 = torch.constant.int 0
 %int1_5610 = torch.constant.int 1
- %4683 = torch.aten.slice.Tensor %4680, %int0_5608, %int0_5609, %4682, %int1_5610 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %4683, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
+ %4698 = torch.aten.slice.Tensor %4695, %int0_5608, %int0_5609, %298, %int1_5610 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %4698, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
 %int1_5611 = torch.constant.int 1
 %int0_5612 = torch.constant.int 0
 %int9223372036854775807_5613 = torch.constant.int 9223372036854775807
 %int1_5614 = torch.constant.int 1
- %4684 = torch.aten.slice.Tensor %4683, %int1_5611, %int0_5612, %int9223372036854775807_5613, %int1_5614 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %4684, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
- %int1_5615 = torch.constant.int 1
+ %4699 = torch.aten.slice.Tensor %4698, %int1_5611, %int0_5612, %int9223372036854775807_5613, %int1_5614 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %4699, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int0_5615 = torch.constant.int 0
 %int0_5616 = torch.constant.int 0
- %int9223372036854775807_5617 = torch.constant.int 9223372036854775807
+ %int1_5617 = torch.constant.int 1
+ %4700 = torch.aten.slice.Tensor %4697, %int0_5615, %int0_5616, %298, %int1_5617 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %4700, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
 %int1_5618 = torch.constant.int 1
- %4685 = torch.aten.slice.Tensor %4684, %int1_5615, %int0_5616, %int9223372036854775807_5617, %int1_5618 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %4685, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
 %int0_5619 = torch.constant.int 0
- %4686 = torch.aten.unsqueeze %4685, %int0_5619 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %4686, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
- %int1_5620 = torch.constant.int 1
- %int0_5621 = torch.constant.int 0
- %int9223372036854775807_5622 = torch.constant.int 9223372036854775807
+ %int9223372036854775807_5620 = torch.constant.int 9223372036854775807
+ %int1_5621 = torch.constant.int 1
+ %4701 = torch.aten.slice.Tensor %4700, %int1_5618, %int0_5619, %int9223372036854775807_5620, %int1_5621 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %4701, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int0_5622 = torch.constant.int 0
+ %4702 = torch.aten.unsqueeze %4699, %int0_5622 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %4702, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
 %int1_5623 = torch.constant.int 1
- %4687 = torch.aten.slice.Tensor %4686, %int1_5620, %int0_5621, %int9223372036854775807_5622, %int1_5623 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %4687, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
- %int2_5624 = torch.constant.int 2
- %int0_5625 = torch.constant.int 0
- %int9223372036854775807_5626 = torch.constant.int 9223372036854775807
- %int1_5627 = torch.constant.int 1
- %4688 = torch.aten.slice.Tensor %4687, %int2_5624, %int0_5625, %int9223372036854775807_5626, %int1_5627 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %4688, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
- %int4_5628 = torch.constant.int 4
- %int1_5629 = torch.constant.int 1
- %int1_5630 = torch.constant.int 1
- %4689 = torch.prim.ListConstruct %int4_5628, %int1_5629, %int1_5630 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4690 = torch.aten.repeat %4688, %4689 : !torch.vtensor<[1,?,128],f32>, !torch.list<int> -> !torch.vtensor<[4,?,128],f32>
- torch.bind_symbolic_shape %4690, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32>
- %int6_5631 = torch.constant.int 6
- %4691 = torch.prims.convert_element_type %4664, %int6_5631 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32>
- torch.bind_symbolic_shape %4691, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32>
- %4692 = torch_c.to_builtin_tensor %4691 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32>
- %4693 = torch_c.to_builtin_tensor %4690 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32>
- %4694 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%4692, %4693) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32>
- %4695 = torch_c.from_builtin_tensor %4694 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32>
- torch.bind_symbolic_shape %4695, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32>
- %int5_5632 = torch.constant.int 5
- %4696 = torch.prims.convert_element_type %4695, %int5_5632 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %4696, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int131072_5633 = torch.constant.int 131072
- %none_5634 = torch.constant.none
- %none_5635 = torch.constant.none
- %cpu_5636 = torch.constant.device "cpu"
- %false_5637 = torch.constant.bool false
- %4697 = torch.aten.arange %int131072_5633, %none_5634, %none_5635, %cpu_5636, %false_5637 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
+ %int0_5624 = torch.constant.int 0
+ %int9223372036854775807_5625 = torch.constant.int 9223372036854775807
+ %int1_5626 = torch.constant.int 1
+ %4703 = torch.aten.slice.Tensor %4702, %int1_5623, %int0_5624, %int9223372036854775807_5625, %int1_5626 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %4703, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int2_5627 = torch.constant.int 2
+ %4704 = torch.aten.unsqueeze %4703, %int2_5627 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %4704, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_5628 = torch.constant.int 3
+ %int0_5629 = torch.constant.int 0
+ %int9223372036854775807_5630 = torch.constant.int 9223372036854775807
+ %int1_5631 = torch.constant.int 1
+ %4705 = torch.aten.slice.Tensor %4704, %int3_5628, %int0_5629, %int9223372036854775807_5630, %int1_5631 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %4705, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int4_5632 = torch.constant.int 4
+ %int1_5633 = torch.constant.int 1
+ %int1_5634 = torch.constant.int 1
+ %int1_5635 = torch.constant.int 1
+ %4706 = torch.prim.ListConstruct %int4_5632, %int1_5633, %int1_5634, %int1_5635 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %4707 = torch.aten.repeat %4705, %4706 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %4707, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+ %int0_5636 = torch.constant.int 0
+ %4708 = torch.aten.unsqueeze %4701, %int0_5636 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %4708, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int1_5637 = torch.constant.int 1
 %int0_5638 = torch.constant.int 0
- %int128_5639 = torch.constant.int 128
- %none_5640 = torch.constant.none
- %none_5641 = torch.constant.none
- %cpu_5642 = torch.constant.device "cpu"
- %false_5643 = torch.constant.bool false
- %4698 = torch.aten.arange.start %int0_5638, %int128_5639, %none_5640, %none_5641, %cpu_5642, %false_5643 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64>
- %int2_5644 = torch.constant.int 2
- %4699 = torch.aten.floor_divide.Scalar %4698, %int2_5644 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64>
- %int6_5645 = torch.constant.int 6
- %4700 = torch.prims.convert_element_type %4699, %int6_5645 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32>
- %int128_5646 = torch.constant.int 128
- %4701 = torch.aten.div.Scalar %4700, %int128_5646 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
- %float2.000000e00_5647 = torch.constant.float 2.000000e+00
- %4702 = torch.aten.mul.Scalar %4701, %float2.000000e00_5647 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
- %float5.000000e05_5648 = torch.constant.float 5.000000e+05
- %4703 = torch.aten.pow.Scalar %float5.000000e05_5648, %4702 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
- %4704 = torch.aten.reciprocal %4703 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
- %float1.000000e00_5649 = torch.constant.float 1.000000e+00
- %4705 = torch.aten.mul.Scalar %4704, %float1.000000e00_5649 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
- %int1_5650 = torch.constant.int 1
- %4706 = torch.aten.unsqueeze %4697, %int1_5650 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
+ %int9223372036854775807_5639 = torch.constant.int 9223372036854775807
+ %int1_5640 = torch.constant.int 1
+ %4709 = torch.aten.slice.Tensor %4708, %int1_5637, %int0_5638, %int9223372036854775807_5639, %int1_5640 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %4709, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int2_5641 = torch.constant.int 2
+ %4710 = torch.aten.unsqueeze %4709, %int2_5641 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %4710, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_5642 = torch.constant.int 3
+ %int0_5643 = torch.constant.int 0
+ %int9223372036854775807_5644 = torch.constant.int 9223372036854775807
+ %int1_5645 = torch.constant.int 1
+ %4711 = torch.aten.slice.Tensor %4710, %int3_5642, %int0_5643, %int9223372036854775807_5644, %int1_5645 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %4711, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int4_5646 = torch.constant.int 4
+ %int1_5647 = torch.constant.int 1
+ %int1_5648 = torch.constant.int 1
+ %int1_5649 = torch.constant.int 1
+ %4712 = torch.prim.ListConstruct %int4_5646, %int1_5647, %int1_5648, %int1_5649 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %4713 = torch.aten.repeat %4711, %4712 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %4713, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+ %4714 = torch.aten.mul.Tensor %4654, %4707 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %4714, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int3_5650 = torch.constant.int 3
 %int0_5651 = torch.constant.int 0
- %4707 = torch.aten.unsqueeze %4705, %int0_5651 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
- %4708 = torch.aten.mul.Tensor %4706, %4707 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
- %int1_5652 = torch.constant.int 1
- %4709 = torch.aten.size.int %4655, %int1_5652 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int
- %int0_5653 = torch.constant.int 0
- %4710 = torch.aten.add.int %int0_5653, %4709 : !torch.int, !torch.int -> !torch.int
- %int0_5654 = torch.constant.int 0
- %int0_5655 = torch.constant.int 0
- %int1_5656 = torch.constant.int 1
- %4711 = torch.aten.slice.Tensor %4708, %int0_5654, %int0_5655, %4710, %int1_5656 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %4711, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
+ %int64_5652 = torch.constant.int 64
+ %int1_5653 = torch.constant.int 1
+ %4715 = torch.aten.slice.Tensor %4654, %int3_5650, %int0_5651, %int64_5652, %int1_5653 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16>
+ torch.bind_symbolic_shape %4715, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16>
+ %int3_5654 = torch.constant.int 3
+ %int64_5655 = torch.constant.int 64
+ %int9223372036854775807_5656 = torch.constant.int 9223372036854775807
 %int1_5657 = torch.constant.int 1
- %int0_5658 = torch.constant.int 0
- %int9223372036854775807_5659 = torch.constant.int 9223372036854775807
- %int1_5660 = torch.constant.int 1
- %4712 = torch.aten.slice.Tensor %4711, %int1_5657, %int0_5658, %int9223372036854775807_5659, %int1_5660 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %4712, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
- %int1_5661 = torch.constant.int 1
- %int0_5662 = torch.constant.int 0
- %int9223372036854775807_5663 = torch.constant.int 9223372036854775807
- %int1_5664 = torch.constant.int 1
- %4713 = torch.aten.slice.Tensor %4712, %int1_5661, %int0_5662, %int9223372036854775807_5663, %int1_5664 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32>
- torch.bind_symbolic_shape %4713, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32>
+ %4716 = torch.aten.slice.Tensor %4654, %int3_5654, %int64_5655, %int9223372036854775807_5656, %int1_5657 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16>
+ torch.bind_symbolic_shape %4716, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16>
+ %4717 = torch.aten.neg %4716 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16>
+ torch.bind_symbolic_shape %4717, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16>
+ %4718 = torch.prim.ListConstruct %4717, %4715 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list<vtensor>
+ %int-1_5658 = torch.constant.int -1
+ %4719 = torch.aten.cat %4718, %int-1_5658 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %4719, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %4720 = torch.aten.mul.Tensor %4719, %4713 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %4720, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int1_5659 = torch.constant.int 1
+ %4721 = torch.aten.add.Tensor %4714, %4720, %int1_5659 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %4721, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int131072_5660 = torch.constant.int 131072
+ %none_5661 = torch.constant.none
+ %none_5662 = torch.constant.none
+ %cpu_5663 = torch.constant.device "cpu"
+ %false_5664 = torch.constant.bool false
+ %4722 = torch.aten.arange %int131072_5660, %none_5661, %none_5662, %cpu_5663, %false_5664 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
 %int0_5665 = torch.constant.int 0
- %4714 = torch.aten.unsqueeze %4713, %int0_5665 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %4714, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
- %int1_5666 = torch.constant.int 1
- %int0_5667 = torch.constant.int 0
- %int9223372036854775807_5668 = torch.constant.int 9223372036854775807
- %int1_5669 = torch.constant.int 1
- %4715 = torch.aten.slice.Tensor %4714, %int1_5666, %int0_5667, %int9223372036854775807_5668, %int1_5669 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %4715, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
- %int2_5670 = torch.constant.int 2
- %int0_5671 = torch.constant.int 0
- %int9223372036854775807_5672 = torch.constant.int 9223372036854775807
- %int1_5673 = torch.constant.int 1
- %4716 = torch.aten.slice.Tensor %4715, %int2_5670, %int0_5671, %int9223372036854775807_5672, %int1_5673 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32>
- torch.bind_symbolic_shape %4716, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32>
- %int4_5674 = torch.constant.int 4
- %int1_5675 = torch.constant.int 1
- %int1_5676 = torch.constant.int 1
- %4717 = torch.prim.ListConstruct %int4_5674, %int1_5675, %int1_5676 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4718 = torch.aten.repeat %4716, %4717 : !torch.vtensor<[1,?,128],f32>, !torch.list<int> -> !torch.vtensor<[4,?,128],f32>
- torch.bind_symbolic_shape %4718, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32>
- %int6_5677 = torch.constant.int 6
- %4719 = torch.prims.convert_element_type %4666, %int6_5677 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32>
- torch.bind_symbolic_shape %4719, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32>
- %4720 = torch_c.to_builtin_tensor %4719 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32>
- %4721 = torch_c.to_builtin_tensor %4718 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32>
- %4722 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%4720, %4721) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32>
- %4723 = torch_c.from_builtin_tensor %4722 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32>
- torch.bind_symbolic_shape %4723, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32>
- %int5_5678 = torch.constant.int 5
- %4724 = torch.prims.convert_element_type %4723, %int5_5678 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %4724, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int64_5679 = torch.constant.int 64
- %4725 = torch.aten.mul.Scalar %arg2, %int64_5679 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %4725, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int44 = torch.constant.int 44
+ %int128_5666 = torch.constant.int 128
+ %int2_5667 = torch.constant.int 2
+ %int4_5668 = torch.constant.int 4
+ %none_5669 = torch.constant.none
+ %cpu_5670 = torch.constant.device "cpu"
+ %false_5671 = torch.constant.bool false
+ %4723 = torch.aten.arange.start_step %int0_5665, %int128_5666, %int2_5667, %int4_5668, %none_5669, %cpu_5670, %false_5671 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64>
+ %int6_5672 = torch.constant.int 6
+ %4724 = torch.prims.convert_element_type %4723, %int6_5672 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32>
+ %int128_5673 = torch.constant.int 128
+ %4725 = torch.aten.div.Scalar %4724, %int128_5673 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float5.000000e05_5674 = torch.constant.float 5.000000e+05
+ %4726 = torch.aten.pow.Scalar %float5.000000e05_5674, %4725 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %4727 = torch.aten.reciprocal %4726 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float1.000000e00_5675 = torch.constant.float 1.000000e+00
+ %4728 = torch.aten.mul.Scalar %4727, %float1.000000e00_5675 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %4729 = torch.aten.reciprocal %4728 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float6.283190e00_5676 = torch.constant.float 6.2831853071795862
+ %4730 = torch.aten.mul.Scalar %4729, %float6.283190e00_5676 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %float8.192000e03_5677 = torch.constant.float 8.192000e+03
+ %4731 = torch.aten.gt.Scalar %4730, %float8.192000e03_5677 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %int8_5678 = torch.constant.int 8
+ %4732 = torch.aten.div.Scalar %4728, %int8_5678 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %4733 = torch.aten.where.self %4731, %4732, %4728 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %4734 = torch.aten.reciprocal %4730 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8192_5679 = torch.constant.int 8192
+ %4735 = torch.aten.mul.Scalar %4734, %int8192_5679 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
 %int1_5680 = torch.constant.int 1
- %4726 = torch.aten.add.Scalar %4725, %int44, %int1_5680 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %4726, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int4_5681 = torch.constant.int 4
- %int32_5682 = torch.constant.int 32
- %int8_5683 = torch.constant.int 8
- %int128_5684 = torch.constant.int 128
- %4727 = torch.prim.ListConstruct %int4_5681, %398, %int32_5682, %int8_5683, %int128_5684 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4728 = torch.aten.view %4724, %4727 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %4728, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int4_5685 = torch.constant.int 4
- %4729 = torch.aten.mul.int %int4_5685, %398 : !torch.int, !torch.int -> !torch.int
- %int32_5686 = torch.constant.int 32
- %int8_5687 = torch.constant.int 8
- %int128_5688 = torch.constant.int 128
- %4730 = torch.prim.ListConstruct %4729, %int32_5686, %int8_5687, %int128_5688 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4731 = torch.aten.view %4728, %4730 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %4731, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int4_5689 = torch.constant.int 4
- %4732 = torch.aten.mul.int %int4_5689, %398 : !torch.int, !torch.int -> !torch.int
- %4733 = torch.prim.ListConstruct %4732 : (!torch.int) -> !torch.list<int>
- %4734 = torch.aten.view %4726, %4733 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
- torch.bind_symbolic_shape %4734, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
- %int32_5690 = torch.constant.int 32
- %int2_5691 = torch.constant.int 2
- %int32_5692 = torch.constant.int 32
- %int8_5693 = torch.constant.int 8
- %int128_5694 = torch.constant.int 128
- %4735 = torch.prim.ListConstruct %389, %int32_5690, %int2_5691, %int32_5692, %int8_5693, %int128_5694 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4736 = torch.aten.view %4568, %4735 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %4736, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int32_5695 = torch.constant.int 32
- %4737 = torch.aten.mul.int %389, %int32_5695 : !torch.int, !torch.int -> !torch.int
- %int2_5696 = torch.constant.int 2
- %4738 = torch.aten.mul.int %4737, %int2_5696 : !torch.int, !torch.int -> !torch.int
- %int32_5697 = torch.constant.int 32
- %int8_5698 = torch.constant.int 8
- %int128_5699 = torch.constant.int 128
- %4739 = torch.prim.ListConstruct %4738, %int32_5697, %int8_5698, %int128_5699 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4740 = torch.aten.view %4736, %4739 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %4740, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %4741 = torch.prim.ListConstruct %4734 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
- %false_5700 = torch.constant.bool false
- %4742 = torch.aten.index_put %4740, %4741, %4731, %false_5700 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %4742, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int32_5701 = torch.constant.int 32
- %int2_5702 = torch.constant.int 2
- %int32_5703 = torch.constant.int 32
- %int8_5704 = torch.constant.int 8
- %int128_5705 = torch.constant.int 128
- %4743 = torch.prim.ListConstruct %389, %int32_5701, %int2_5702, %int32_5703, %int8_5704, %int128_5705 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4744 = torch.aten.view %4742, %4743 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %4744, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int2097152_5706 = torch.constant.int 2097152
- %4745 = torch.prim.ListConstruct %389, %int2097152_5706 : (!torch.int, !torch.int) -> !torch.list<int>
- %4746 = torch.aten.view %4744, %4745 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
- torch.bind_symbolic_shape %4746, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
- %int32_5707 = torch.constant.int 32
- %int2_5708 = torch.constant.int 2
- %int32_5709 = torch.constant.int 32
- %int8_5710 = torch.constant.int 8
- %int128_5711 = torch.constant.int 128
- %4747 = torch.prim.ListConstruct %389, %int32_5707, %int2_5708, %int32_5709, %int8_5710, %int128_5711 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4748 = torch.aten.view %4746, %4747 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %4748, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int32_5712 = torch.constant.int 32
- %int8_5713 = torch.constant.int 8
- %int128_5714 = torch.constant.int 128
- %4749 = torch.prim.ListConstruct %4738, %int32_5712, %int8_5713, %int128_5714 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4750 = torch.aten.view %4748, %4749 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %4750, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int4_5715 = torch.constant.int 4
- %int32_5716 = torch.constant.int 32
- %int8_5717 = torch.constant.int 8
- %int128_5718 = torch.constant.int 128
- %4751 = torch.prim.ListConstruct %int4_5715, %398, %int32_5716, %int8_5717, %int128_5718 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4752 = torch.aten.view %4668, %4751 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %4752, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int4_5719 = torch.constant.int 4
- %4753 = torch.aten.mul.int %int4_5719, %398 : !torch.int, !torch.int -> !torch.int
- %int32_5720 = torch.constant.int 32
- %int8_5721 = torch.constant.int 8
- %int128_5722 = torch.constant.int 128
- %4754 = torch.prim.ListConstruct %4753, %int32_5720, %int8_5721, %int128_5722 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4755 = torch.aten.view %4752, %4754 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %4755, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
+ %int1_5681 = torch.constant.int 1
+ %4736 = torch.aten.sub.Scalar %4735, %int1_5680, %int1_5681 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %int3_5682 = torch.constant.int 3
+ %4737 = torch.aten.div.Scalar %4736, %int3_5682 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %int1_5683 = torch.constant.int 1
+ %int1_5684 = torch.constant.int 1
+ %4738 = torch.aten.rsub.Scalar %4737, %int1_5683, %int1_5684 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %4739 = torch.aten.mul.Tensor %4738, %4733 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8_5685 = torch.constant.int 8
+ %4740 = torch.aten.div.Scalar %4739, %int8_5685 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %4741 = torch.aten.mul.Tensor %4737, %4733 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int1_5686 = torch.constant.int 1
+ %4742 = torch.aten.add.Tensor %4740, %4741, %int1_5686 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float2.048000e03_5687 = torch.constant.float 2.048000e+03
+ %4743 = torch.aten.lt.Scalar %4730, %float2.048000e03_5687 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %4744 = torch.aten.bitwise_not %4743 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %float8.192000e03_5688 = torch.constant.float 8.192000e+03
+ %4745 = torch.aten.gt.Scalar %4730, %float8.192000e03_5688 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %4746 = torch.aten.bitwise_not %4745 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %4747 = torch.aten.mul.Tensor %4744, %4746 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %4748 = torch.aten.where.self %4747, %4742, %4733 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %4749 = torch.prim.ListConstruct %4748, %4748 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list<vtensor>
+ %int-1_5689 = torch.constant.int -1
+ %4750 = torch.aten.cat %4749, %int-1_5689 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[128],f32>
+ %int6_5690 = torch.constant.int 6
+ %4751 = torch.prims.convert_element_type %4750, %int6_5690 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
+ %int1_5691 = torch.constant.int 1
+ %4752 = torch.aten.unsqueeze %4722, %int1_5691 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
+ %int6_5692 = torch.constant.int 6
+ %4753 = torch.prims.convert_element_type %4752, %int6_5692 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32>
+ %int0_5693 = torch.constant.int 0
+ %4754 = torch.aten.unsqueeze %4751, %int0_5693 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+ %int6_5694 = torch.constant.int 6
+ %4755 = torch.prims.convert_element_type %4754, %int6_5694 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+ %4756 = torch.aten.mul.Tensor %4753, %4755 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %4757 = torch.aten.cos %4756 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %int5_5695 = torch.constant.int 5
+ %4758 = torch.prims.convert_element_type %4757, %int5_5695 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
+ %4759 = torch.aten.sin %4756 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %int5_5696 = torch.constant.int 5
+ %4760 = torch.prims.convert_element_type %4759, %int5_5696 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
+ %int0_5697 = torch.constant.int 0
+ %int0_5698 = torch.constant.int 0
+ %int1_5699 = torch.constant.int 1
+ %4761 = torch.aten.slice.Tensor %4758, %int0_5697, %int0_5698, %298, %int1_5699 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %4761, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int1_5700 = torch.constant.int 1
+ %int0_5701 = torch.constant.int 0
+ %int9223372036854775807_5702 = torch.constant.int 9223372036854775807
+ %int1_5703 = torch.constant.int 1
+ %4762 = torch.aten.slice.Tensor %4761, %int1_5700, %int0_5701, %int9223372036854775807_5702, %int1_5703 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %4762, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int0_5704 = torch.constant.int 0
+ %int0_5705 = torch.constant.int 0
+ %int1_5706 = torch.constant.int 1
+ %4763 = torch.aten.slice.Tensor %4760, %int0_5704, %int0_5705, %298, %int1_5706 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %4763, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int1_5707 = torch.constant.int 1
+ %int0_5708 = torch.constant.int 0
+ %int9223372036854775807_5709 = torch.constant.int 9223372036854775807
+ %int1_5710 = torch.constant.int 1
+ %4764 = torch.aten.slice.Tensor %4763, %int1_5707, %int0_5708, %int9223372036854775807_5709, %int1_5710 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %4764, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int0_5711 = torch.constant.int 0
+ %4765 = torch.aten.unsqueeze %4762, %int0_5711 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %4765, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int1_5712 = torch.constant.int 1
+ %int0_5713 = torch.constant.int 0
+ %int9223372036854775807_5714 = torch.constant.int 9223372036854775807
+ %int1_5715 = torch.constant.int 1
+ %4766 = torch.aten.slice.Tensor %4765, %int1_5712, %int0_5713, %int9223372036854775807_5714, %int1_5715 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %4766, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int2_5716 = torch.constant.int 2
+ %4767 = torch.aten.unsqueeze %4766, %int2_5716 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %4767, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_5717 = torch.constant.int 3
+ %int0_5718 = torch.constant.int 0
+ %int9223372036854775807_5719 = torch.constant.int 9223372036854775807
+ %int1_5720 = torch.constant.int 1
+ %4768 = torch.aten.slice.Tensor %4767, %int3_5717, %int0_5718, %int9223372036854775807_5719, %int1_5720 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %4768, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int4_5721 = torch.constant.int 4
+ %int1_5722 = torch.constant.int 1
 %int1_5723 = torch.constant.int 1
 %int1_5724 = torch.constant.int 1
- %4756 = torch.aten.add.Scalar %4726, %int1_5723, %int1_5724 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %4756, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int4_5725 = torch.constant.int 4
- %4757 = torch.aten.mul.int %int4_5725, %398 : !torch.int, !torch.int -> !torch.int
- %4758 = torch.prim.ListConstruct %4757 : (!torch.int) -> !torch.list<int>
- %4759 = torch.aten.view %4756, %4758 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
- torch.bind_symbolic_shape %4759, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
- %4760 = torch.prim.ListConstruct %4759 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
- %false_5726 = torch.constant.bool false
- %4761 = torch.aten.index_put %4750, %4760, %4755, %false_5726 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %4761, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int32_5727 = torch.constant.int 32
- %int2_5728 = torch.constant.int 2
- %int32_5729 = torch.constant.int 32
- %int8_5730 = torch.constant.int 8
- %int128_5731 = torch.constant.int 128
- %4762 = torch.prim.ListConstruct %389, %int32_5727, %int2_5728, %int32_5729, %int8_5730, %int128_5731 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4763 = torch.aten.view %4761, %4762 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %4763, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int2097152_5732 = torch.constant.int 2097152
- %4764 = torch.prim.ListConstruct %389, %int2097152_5732 : (!torch.int, !torch.int) -> !torch.list<int>
- %4765 = torch.aten.view %4763, %4764 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
- torch.bind_symbolic_shape %4765, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
- %int-2_5733 = torch.constant.int -2
- %4766 = torch.aten.unsqueeze %4724, %int-2_5733 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
- torch.bind_symbolic_shape %4766, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
- %int4_5734 = torch.constant.int 4
- %int8_5735 = torch.constant.int 8
- %int4_5736 = torch.constant.int 4
- %int128_5737 = torch.constant.int 128
- %4767 = torch.prim.ListConstruct %int4_5734, %4709, %int8_5735, %int4_5736, %int128_5737 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %false_5738 = torch.constant.bool false
- %4768 = torch.aten.expand %4766, %4767, %false_5738 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %4768, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int0_5739 = torch.constant.int 0
- %4769 = torch.aten.clone %4768, %int0_5739 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %4769, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int4_5740 = torch.constant.int 4
- %int32_5741 = torch.constant.int 32
- %int128_5742 = torch.constant.int 128
- %4770 = torch.prim.ListConstruct %int4_5740, %4709, %int32_5741, %int128_5742 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4771 = torch.aten._unsafe_view %4769, %4770 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %4771, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int-2_5743 = torch.constant.int -2
- %4772 = torch.aten.unsqueeze %4668, %int-2_5743 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
- torch.bind_symbolic_shape %4772, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
- %int1_5744 = torch.constant.int 1
- %4773 = torch.aten.size.int %4662, %int1_5744 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int
- %int4_5745 = torch.constant.int 4
- %int8_5746 = torch.constant.int 8
- %int4_5747 = torch.constant.int 4
- %int128_5748 = torch.constant.int 128
- %4774 = torch.prim.ListConstruct %int4_5745, %4773, %int8_5746, %int4_5747, %int128_5748 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %false_5749 = torch.constant.bool false
- %4775 = torch.aten.expand %4772, %4774, %false_5749 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %4775, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int0_5750 = torch.constant.int 0
- %4776 = torch.aten.clone %4775, %int0_5750 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %4776, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int4_5751 = torch.constant.int 4
- %int32_5752 = torch.constant.int 32
- %int128_5753 = torch.constant.int 128
- %4777 = torch.prim.ListConstruct %int4_5751, %4773, %int32_5752, %int128_5753 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4778 = torch.aten._unsafe_view %4776, %4777 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %4778, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int1_5754 = torch.constant.int 1
- %int2_5755 = torch.constant.int 2
- %4779 = torch.aten.transpose.int %4696, %int1_5754, %int2_5755 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %4779, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_5756 = torch.constant.int 1
- %int2_5757 = torch.constant.int 2
- %4780 = torch.aten.transpose.int %4771, %int1_5756, %int2_5757 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %4780, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_5758 = torch.constant.int 1
- %int2_5759 = torch.constant.int 2
- %4781 = torch.aten.transpose.int %4778, %int1_5758, %int2_5759 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %4781, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %float0.000000e00_5760 = torch.constant.float 0.000000e+00
- %true_5761 = torch.constant.bool true
- %none_5762 = torch.constant.none
- %none_5763 = torch.constant.none
- %4782:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4779, %4780, %4781, %float0.000000e00_5760, %true_5761, %none_5762, %none_5763) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>)
- torch.bind_symbolic_shape %4782#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_5764 = torch.constant.int 1
+ %4769 = torch.prim.ListConstruct %int4_5721, %int1_5722, %int1_5723, %int1_5724 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %4770 = torch.aten.repeat %4768, %4769 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %4770, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+ %int0_5725 = torch.constant.int 0
+ %4771 = torch.aten.unsqueeze %4764, %int0_5725 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %4771, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int1_5726 = torch.constant.int 1
+ %int0_5727 = torch.constant.int 0
+ %int9223372036854775807_5728 = torch.constant.int 9223372036854775807
+ %int1_5729 = torch.constant.int 1
+ %4772 = torch.aten.slice.Tensor %4771, %int1_5726, %int0_5727, %int9223372036854775807_5728, %int1_5729 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %4772, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int2_5730 = torch.constant.int 2
+ %4773 = torch.aten.unsqueeze %4772, %int2_5730 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %4773, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_5731 = torch.constant.int 3
+ %int0_5732 = torch.constant.int 0
+ %int9223372036854775807_5733 = torch.constant.int 9223372036854775807
+ %int1_5734 = torch.constant.int 1
+ %4774 = torch.aten.slice.Tensor %4773, %int3_5731, %int0_5732, %int9223372036854775807_5733, %int1_5734 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %4774, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int4_5735 = torch.constant.int 4
+ %int1_5736 = torch.constant.int 1
+ %int1_5737 = torch.constant.int 1
+ %int1_5738 = torch.constant.int 1
+ %4775 = torch.prim.ListConstruct %int4_5735, %int1_5736, %int1_5737, %int1_5738 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %4776 = torch.aten.repeat %4774, %4775 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %4776, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+ %4777 = torch.aten.mul.Tensor %4656, %4770 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %4777, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int3_5739 = torch.constant.int 3
+ %int0_5740 = torch.constant.int 0
+ %int64_5741 = torch.constant.int 64
+ %int1_5742 = torch.constant.int 1
+ %4778 = torch.aten.slice.Tensor %4656, %int3_5739, %int0_5740, %int64_5741, %int1_5742 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16>
+ torch.bind_symbolic_shape %4778, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16>
+ %int3_5743 = torch.constant.int 3
+ %int64_5744 = torch.constant.int 64
+ %int9223372036854775807_5745 = torch.constant.int 9223372036854775807
+ %int1_5746 = torch.constant.int 1
+ %4779 = torch.aten.slice.Tensor %4656, %int3_5743, %int64_5744, %int9223372036854775807_5745, %int1_5746 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16>
+ torch.bind_symbolic_shape %4779, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16>
+ %4780 = torch.aten.neg %4779 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16>
+ torch.bind_symbolic_shape %4780, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16>
+ %4781 = torch.prim.ListConstruct %4780, %4778 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list<vtensor>
+ %int-1_5747 = torch.constant.int -1
+ %4782 = torch.aten.cat %4781, %int-1_5747 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %4782, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %4783 = torch.aten.mul.Tensor %4782, %4776 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %4783, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int1_5748 = torch.constant.int 1
+ %4784 = torch.aten.add.Tensor %4777, %4783, %int1_5748 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %4784, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int32_5749 = torch.constant.int 32
+ %4785 = torch.aten.mul.Scalar %arg2, %int32_5749 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %4785, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int16 = torch.constant.int 16
+ %int1_5750 = torch.constant.int 1
+ %4786 = torch.aten.add.Scalar %4785, %int16, %int1_5750 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %4786, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int2_5751 = torch.constant.int 2
+ %4787 = torch.aten.mul.Scalar %4786, %int2_5751 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %4787, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int0_5752 = torch.constant.int 0
+ %int1_5753 = torch.constant.int 1
+ %4788 = torch.aten.add.Scalar %4787, %int0_5752, %int1_5753 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
+ torch.bind_symbolic_shape %4788, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %4789 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list<int>
+ %4790 = torch.aten.view %4788, %4789 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
+ torch.bind_symbolic_shape %4790, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
+ %int4_5754 = torch.constant.int 4
+ %int32_5755 = torch.constant.int 32
+ %int8_5756 = torch.constant.int 8
+ %int128_5757 = torch.constant.int 128
+ %4791 = torch.prim.ListConstruct %int4_5754, %296, %int32_5755, %int8_5756, %int128_5757 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %4792 = torch.aten.view %4784, %4791 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %4792, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int32_5758 = torch.constant.int 32
+ %int8_5759 = torch.constant.int 8
+ %int128_5760 = torch.constant.int 128
+ %4793 = torch.prim.ListConstruct %504, %int32_5758, %int8_5759, %int128_5760 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %4794 = torch.aten.view %4792, %4793 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
+ torch.bind_symbolic_shape %4794, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
+ %int1_5761 = torch.constant.int 1
+ %int2_5762 = torch.constant.int 2
+ %4795 = torch.aten.transpose.int %4794, %int1_5761, %int2_5762 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %4795, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int5_5763 = torch.constant.int 5
+ %4796 = torch.prims.convert_element_type %4795, %int5_5763 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %4796, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int32_5764 = torch.constant.int 32
 %int2_5765 = torch.constant.int 2
- %4783 = torch.aten.transpose.int %4782#0, %int1_5764, %int2_5765 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %4783, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int4_5766 = torch.constant.int 4
- %int4096_5767 = torch.constant.int 4096
- %4784 = torch.prim.ListConstruct %int4_5766, %4681, %int4096_5767 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4785 = torch.aten.view %4783, %4784 : !torch.vtensor<[4,?,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %4785, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_5768 = torch.constant.int -2
- %int-1_5769 = torch.constant.int -1
- %4786 = torch.aten.transpose.int %203, %int-2_5768, %int-1_5769 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_5770 = torch.constant.int 4
- %4787 = torch.aten.mul.int %int4_5770, %4681 : !torch.int, !torch.int -> !torch.int
- %int4096_5771 = torch.constant.int 4096
- %4788 = torch.prim.ListConstruct %4787, %int4096_5771 : (!torch.int, !torch.int) -> !torch.list<int>
- %4789 = torch.aten.view %4785, %4788 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %4789, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %4790 = torch.aten.mm %4789, %4786 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %4790, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_5772 = torch.constant.int 4
- %int4096_5773 = torch.constant.int 4096
- %4791 = torch.prim.ListConstruct %int4_5772, %4681, %int4096_5773 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4792 = torch.aten.view %4790, %4791 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %4792, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int1_5774 = torch.constant.int 1
- %4793 = torch.aten.add.Tensor %4631, %4792, %int1_5774 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %4793, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int6_5775 = torch.constant.int 6
- %4794 = torch.prims.convert_element_type %4793, %int6_5775 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %4794, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int2_5776 = torch.constant.int 2
- %4795 = torch.aten.pow.Tensor_Scalar %4794, %int2_5776 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %4795, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int-1_5777 = torch.constant.int -1
- %4796 = torch.prim.ListConstruct %int-1_5777 : (!torch.int) -> !torch.list<int>
- %true_5778 = torch.constant.bool true
- %none_5779 = torch.constant.none
- %4797 = torch.aten.mean.dim %4795, %4796, %true_5778, %none_5779 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %4797, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %float9.999990e-06_5780 = torch.constant.float 9.9999997473787516E-6
- %int1_5781 = torch.constant.int 1
- %4798 = torch.aten.add.Scalar %4797, %float9.999990e-06_5780, %int1_5781 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %4798, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %4799 = torch.aten.rsqrt %4798 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %4799, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %4800 = torch.aten.mul.Tensor %4794, %4799 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %4800, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_5782 = torch.constant.int 5
- %4801 = torch.prims.convert_element_type %4800, %int5_5782 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %4801, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %4802 = torch.aten.mul.Tensor %204, %4801 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %4802, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_5783 = torch.constant.int 5
- %4803 = torch.prims.convert_element_type %4802, %int5_5783 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %4803, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_5784 = torch.constant.int -2
- %int-1_5785 = torch.constant.int -1
- %4804 = torch.aten.transpose.int %205, %int-2_5784, %int-1_5785 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_5786 = torch.constant.int 4
- %4805 = torch.aten.mul.int %int4_5786, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_5787 = torch.constant.int 4096
- %4806 = torch.prim.ListConstruct %4805, %int4096_5787 : (!torch.int, !torch.int) -> !torch.list<int>
- %4807 = torch.aten.view %4803, %4806 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %4807, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %4808 = torch.aten.mm %4807, %4804 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %4808, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %int4_5788 = torch.constant.int 4
- %int14336_5789 = torch.constant.int 14336
- %4809 = torch.prim.ListConstruct %int4_5788, %306, %int14336_5789 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4810 = torch.aten.view %4808, %4809 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %4810, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %4811 = torch.aten.silu %4810 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %4811, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %int-2_5790 = torch.constant.int -2
- %int-1_5791 = torch.constant.int -1
- %4812 = torch.aten.transpose.int %206, %int-2_5790, %int-1_5791 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_5792 = torch.constant.int 4
- %4813 = torch.aten.mul.int %int4_5792, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_5793 = torch.constant.int 4096
- %4814 = torch.prim.ListConstruct %4813, %int4096_5793 : (!torch.int, !torch.int) -> !torch.list<int>
- %4815 = torch.aten.view %4803, %4814 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %4815, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %4816 = torch.aten.mm %4815, %4812 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %4816, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %int4_5794 = torch.constant.int 4
- %int14336_5795 = torch.constant.int 14336
- %4817 = torch.prim.ListConstruct %int4_5794, %306, %int14336_5795 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4818 = torch.aten.view %4816, %4817 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %4818, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %4819 = torch.aten.mul.Tensor %4811, %4818 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %4819, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %int-2_5796 = torch.constant.int -2
- %int-1_5797 = torch.constant.int -1
- %4820 = torch.aten.transpose.int %207, %int-2_5796, %int-1_5797 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
- %int1_5798 = torch.constant.int 1
- %4821 = torch.aten.size.int %4810, %int1_5798 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int
- %int4_5799 = torch.constant.int 4
- %4822 = torch.aten.mul.int %int4_5799, %4821 : !torch.int, !torch.int -> !torch.int
- %int14336_5800 = torch.constant.int 14336
- %4823 = torch.prim.ListConstruct %4822, %int14336_5800 : (!torch.int, !torch.int) -> !torch.list<int>
- %4824 = torch.aten.view %4819, %4823 : !torch.vtensor<[4,?,14336],f16>, !torch.list<int> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %4824, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %4825 = torch.aten.mm %4824, %4820 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %4825, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_5801 = torch.constant.int 4
- %int4096_5802 = torch.constant.int 4096
- %4826 = torch.prim.ListConstruct %int4_5801, %4821, %int4096_5802 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4827 = torch.aten.view %4825, %4826 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %4827, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int1_5803 = torch.constant.int 1
- %4828 = torch.aten.add.Tensor %4793, %4827, %int1_5803 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %4828, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int6_5804 = torch.constant.int 6
- %4829 = torch.prims.convert_element_type %4828, %int6_5804 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %4829, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int8_5766 = torch.constant.int 8
+ %int32_5767 = torch.constant.int 32
+ %int128_5768 = torch.constant.int 128
+ %4797 = torch.prim.ListConstruct %297, %int32_5764, %int2_5765, %int8_5766, %int32_5767, %int128_5768 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %4798 = torch.aten.view %4560, %4797 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %4798, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int8_5769 = torch.constant.int 8
+ %int32_5770 = torch.constant.int 32
+ %int128_5771 = torch.constant.int 128
+ %4799 = torch.prim.ListConstruct %497, %int8_5769, %int32_5770, %int128_5771 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %4800 = torch.aten.view %4798, %4799 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %4800, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %4801 = torch.prim.ListConstruct %4790 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
+ %false_5772 = torch.constant.bool false
+ %4802 = torch.aten.index_put %4800, %4801, %4796, %false_5772 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %4802, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int32_5773 = torch.constant.int 32
+ %int2_5774 = torch.constant.int 2
+ %int8_5775 = torch.constant.int 8
+ %int32_5776 = torch.constant.int 32
+ %int128_5777 = torch.constant.int 128
+ %4803 = torch.prim.ListConstruct %297, %int32_5773, %int2_5774, %int8_5775, %int32_5776, %int128_5777 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %4804 = torch.aten.view %4802, %4803 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %4804, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32,
128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_5778 = torch.constant.int 2097152 + %4805 = torch.prim.ListConstruct %297, %int2097152_5778 : (!torch.int, !torch.int) -> !torch.list + %4806 = torch.aten.view %4804, %4805 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %4806, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_5779 = torch.constant.int 32 + %int2_5780 = torch.constant.int 2 + %int8_5781 = torch.constant.int 8 + %int32_5782 = torch.constant.int 32 + %int128_5783 = torch.constant.int 128 + %4807 = torch.prim.ListConstruct %297, %int32_5779, %int2_5780, %int8_5781, %int32_5782, %int128_5783 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4808 = torch.aten.view %4806, %4807 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4808, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_5784 = torch.constant.int 8 + %int32_5785 = torch.constant.int 32 + %int128_5786 = torch.constant.int 128 + %4809 = torch.prim.ListConstruct %497, %int8_5784, %int32_5785, %int128_5786 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4810 = torch.aten.view %4808, %4809 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4810, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_5787 = torch.constant.int 32 + %4811 = torch.aten.mul.Scalar %arg2, %int32_5787 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4811, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int16_5788 = torch.constant.int 16 + %int1_5789 = torch.constant.int 1 + %4812 = torch.aten.add.Scalar %4811, %int16_5788, %int1_5789 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4812, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_5790 = torch.constant.int 2 + %4813 = torch.aten.mul.Scalar %4812, %int2_5790 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4813, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_5791 = torch.constant.int 1 + %int1_5792 = torch.constant.int 1 + %4814 = torch.aten.add.Scalar %4813, %int1_5791, %int1_5792 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %4814, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %4815 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %4816 = torch.aten.view %4814, %4815 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %4816, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_5793 = torch.constant.int 4 + %int32_5794 = torch.constant.int 32 + %int8_5795 = torch.constant.int 8 + %int128_5796 = torch.constant.int 128 + %4817 = torch.prim.ListConstruct %int4_5793, %296, %int32_5794, %int8_5795, %int128_5796 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4818 = torch.aten.view %4658, %4817 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4818, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : 
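
The scalar arithmetic added above (%4811 through %4816) is the paged-KV-cache slot computation: each entry of the %arg2 page table is scaled by 32 (presumably the 32 transformer blocks per page), offset by this block's index (16), doubled, and offset by 1 to select the V plane, matching the cache views of shape [?,32,2,8,32,128]. A minimal PyTorch sketch of the pattern, with hypothetical names and the constants read off the IR:

```python
import torch

def kv_slot_indices(page_ids: torch.Tensor, block_index: int, kv_plane: int) -> torch.Tensor:
    # slot = (page_id * 32 + block_index) * 2 + kv_plane, flattened to 1-D,
    # indexing a cache viewed as [num_pages * 32 * 2, 8, 32, 128].
    return ((page_ids * 32 + block_index) * 2 + kv_plane).reshape(-1)

# Mirrors %4811..%4816: block 16, kv_plane=1 selects the V plane.
page_ids = torch.randint(0, 64, (4, 3), dtype=torch.int64)   # stand-in for %arg2
v_slots = kv_slot_indices(page_ids, block_index=16, kv_plane=1)

cache = torch.zeros(64 * 32 * 2, 8, 32, 128, dtype=torch.float16)
new_v = torch.randn(v_slots.numel(), 8, 32, 128, dtype=torch.float16)
cache.index_put_((v_slots,), new_v)   # the torch.aten.index_put in the IR
```
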
!torch.vtensor<[4,?,32,8,128],f16> + %int32_5797 = torch.constant.int 32 + %int8_5798 = torch.constant.int 8 + %int128_5799 = torch.constant.int 128 + %4819 = torch.prim.ListConstruct %504, %int32_5797, %int8_5798, %int128_5799 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4820 = torch.aten.view %4818, %4819 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %4820, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_5800 = torch.constant.int 1 + %int2_5801 = torch.constant.int 2 + %4821 = torch.aten.transpose.int %4820, %int1_5800, %int2_5801 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4821, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_5802 = torch.constant.int 5 + %4822 = torch.prims.convert_element_type %4821, %int5_5802 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4822, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %4823 = torch.prim.ListConstruct %4816 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_5803 = torch.constant.bool false + %4824 = torch.aten.index_put %4810, %4823, %4822, %false_5803 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %4824, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_5804 = torch.constant.int 32 %int2_5805 = torch.constant.int 2 - %4830 = torch.aten.pow.Tensor_Scalar %4829, %int2_5805 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4830, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_5806 = torch.constant.int -1 - %4831 = torch.prim.ListConstruct %int-1_5806 : (!torch.int) -> !torch.list - %true_5807 = torch.constant.bool true - %none_5808 = torch.constant.none - %4832 = torch.aten.mean.dim %4830, %4831, %true_5807, %none_5808 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4832, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_5809 = torch.constant.float 9.9999997473787516E-6 - %int1_5810 = torch.constant.int 1 - %4833 = torch.aten.add.Scalar %4832, %float9.999990e-06_5809, %int1_5810 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4833, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4834 = torch.aten.rsqrt %4833 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4834, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4835 = torch.aten.mul.Tensor %4829, %4834 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4835, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_5811 = torch.constant.int 5 - %4836 = torch.prims.convert_element_type %4835, %int5_5811 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4836, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : 
!torch.vtensor<[4,?,4096],f16> - %4837 = torch.aten.mul.Tensor %208, %4836 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4837, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_5812 = torch.constant.int 5 - %4838 = torch.prims.convert_element_type %4837, %int5_5812 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4838, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_5813 = torch.constant.int -2 - %int-1_5814 = torch.constant.int -1 - %4839 = torch.aten.transpose.int %209, %int-2_5813, %int-1_5814 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_5815 = torch.constant.int 4 - %4840 = torch.aten.mul.int %int4_5815, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5816 = torch.constant.int 4096 - %4841 = torch.prim.ListConstruct %4840, %int4096_5816 : (!torch.int, !torch.int) -> !torch.list - %4842 = torch.aten.view %4838, %4841 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4842, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4843 = torch.aten.mm %4842, %4839 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4843, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int8_5806 = torch.constant.int 8 + %int32_5807 = torch.constant.int 32 + %int128_5808 = torch.constant.int 128 + %4825 = torch.prim.ListConstruct %297, %int32_5804, %int2_5805, %int8_5806, %int32_5807, %int128_5808 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4826 = torch.aten.view %4824, %4825 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4826, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_5809 = torch.constant.int 2097152 + %4827 = torch.prim.ListConstruct %297, %int2097152_5809 : (!torch.int, !torch.int) -> !torch.list + %4828 = torch.aten.view %4826, %4827 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %4828, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_5810 = torch.constant.int -2 + %4829 = torch.aten.unsqueeze %4784, %int-2_5810 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %4829, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_5811 = torch.constant.int 4 + %int8_5812 = torch.constant.int 8 + %int4_5813 = torch.constant.int 4 + %int128_5814 = torch.constant.int 128 + %4830 = torch.prim.ListConstruct %int4_5811, %298, %int8_5812, %int4_5813, %int128_5814 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_5815 = torch.constant.bool false + %4831 = torch.aten.expand %4829, %4830, %false_5815 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4831, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_5816 = torch.constant.int 0 + %4832 = torch.aten.clone %4831, %int0_5816 : !torch.vtensor<[4,?,8,4,128],f16>, 
!torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4832, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_5817 = torch.constant.int 4 - %int4096_5818 = torch.constant.int 4096 - %4844 = torch.prim.ListConstruct %int4_5817, %306, %int4096_5818 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4845 = torch.aten.view %4843, %4844 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4845, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_5819 = torch.constant.int -2 - %int-1_5820 = torch.constant.int -1 - %4846 = torch.aten.transpose.int %210, %int-2_5819, %int-1_5820 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int32_5818 = torch.constant.int 32 + %int128_5819 = torch.constant.int 128 + %4833 = torch.prim.ListConstruct %int4_5817, %298, %int32_5818, %int128_5819 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4834 = torch.aten._unsafe_view %4832, %4833 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4834, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_5820 = torch.constant.int -2 + %4835 = torch.aten.unsqueeze %4658, %int-2_5820 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %4835, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> %int4_5821 = torch.constant.int 4 - %4847 = torch.aten.mul.int %int4_5821, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5822 = torch.constant.int 4096 - %4848 = torch.prim.ListConstruct %4847, %int4096_5822 : (!torch.int, !torch.int) -> !torch.list - %4849 = torch.aten.view %4838, %4848 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4849, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4850 = torch.aten.mm %4849, %4846 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %4850, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int8_5822 = torch.constant.int 8 %int4_5823 = torch.constant.int 4 - %int1024_5824 = torch.constant.int 1024 - %4851 = torch.prim.ListConstruct %int4_5823, %306, %int1024_5824 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4852 = torch.aten.view %4850, %4851 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %4852, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_5825 = torch.constant.int -2 - %int-1_5826 = torch.constant.int -1 - %4853 = torch.aten.transpose.int %211, %int-2_5825, %int-1_5826 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int128_5824 = torch.constant.int 128 + %4836 = torch.prim.ListConstruct %int4_5821, %298, %int8_5822, %int4_5823, %int128_5824 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_5825 = torch.constant.bool false + %4837 = torch.aten.expand %4835, %4836, %false_5825 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4837, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : 
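
The unsqueeze/expand/clone/_unsafe_view chains above are the standard grouped-query-attention head expansion: each of the 8 KV heads is replicated 4 times to line up with the 32 query heads. A sketch under the same [batch, seq, heads, head_dim] layout as the IR:

```python
import torch

def expand_kv_heads(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Matches the IR's unsqueeze(-2) -> expand -> clone -> view sequence.
    bs, seq, n_kv_heads, head_dim = x.shape
    x = x.unsqueeze(-2).expand(bs, seq, n_kv_heads, n_rep, head_dim)
    return x.reshape(bs, seq, n_kv_heads * n_rep, head_dim)

k = torch.randn(4, 32, 8, 128, dtype=torch.float16)
k32 = expand_kv_heads(k, n_rep=4)   # [4, 32, 32, 128]
assert torch.equal(k32, torch.repeat_interleave(k, 4, dim=2))
```
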
!torch.vtensor<[4,?,8,4,128],f16> + %int0_5826 = torch.constant.int 0 + %4838 = torch.aten.clone %4837, %int0_5826 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4838, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_5827 = torch.constant.int 4 - %4854 = torch.aten.mul.int %int4_5827, %306 : !torch.int, !torch.int -> !torch.int - %int4096_5828 = torch.constant.int 4096 - %4855 = torch.prim.ListConstruct %4854, %int4096_5828 : (!torch.int, !torch.int) -> !torch.list - %4856 = torch.aten.view %4838, %4855 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4856, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4857 = torch.aten.mm %4856, %4853 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %4857, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_5829 = torch.constant.int 4 - %int1024_5830 = torch.constant.int 1024 - %4858 = torch.prim.ListConstruct %int4_5829, %306, %int1024_5830 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4859 = torch.aten.view %4857, %4858 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %4859, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_5831 = torch.constant.int 4 - %int32_5832 = torch.constant.int 32 - %int128_5833 = torch.constant.int 128 - %4860 = torch.prim.ListConstruct %int4_5831, %306, %int32_5832, %int128_5833 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4861 = torch.aten.view %4845, %4860 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4861, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_5834 = torch.constant.int 4 - %int8_5835 = torch.constant.int 8 - %int128_5836 = torch.constant.int 128 - %4862 = torch.prim.ListConstruct %int4_5834, %306, %int8_5835, %int128_5836 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4863 = torch.aten.view %4852, %4862 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4863, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_5837 = torch.constant.int 4 - %int8_5838 = torch.constant.int 8 - %int128_5839 = torch.constant.int 128 - %4864 = torch.prim.ListConstruct %int4_5837, %306, %int8_5838, %int128_5839 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4865 = torch.aten.view %4859, %4864 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4865, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_5840 = torch.constant.int 131072 - %none_5841 = torch.constant.none - %none_5842 = torch.constant.none - %cpu_5843 = torch.constant.device "cpu" - %false_5844 = torch.constant.bool false - %4866 = torch.aten.arange %int131072_5840, %none_5841, %none_5842, %cpu_5843, %false_5844 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_5845 = torch.constant.int 0 - %int128_5846 = torch.constant.int 128 - %none_5847 = torch.constant.none - %none_5848 = torch.constant.none - %cpu_5849 = 
torch.constant.device "cpu" - %false_5850 = torch.constant.bool false - %4867 = torch.aten.arange.start %int0_5845, %int128_5846, %none_5847, %none_5848, %cpu_5849, %false_5850 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> + %int32_5828 = torch.constant.int 32 + %int128_5829 = torch.constant.int 128 + %4839 = torch.prim.ListConstruct %int4_5827, %298, %int32_5828, %int128_5829 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4840 = torch.aten._unsafe_view %4838, %4839 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4840, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_5830 = torch.constant.int 1 + %int2_5831 = torch.constant.int 2 + %4841 = torch.aten.transpose.int %4721, %int1_5830, %int2_5831 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4841, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_5832 = torch.constant.int 1 + %int2_5833 = torch.constant.int 2 + %4842 = torch.aten.transpose.int %4834, %int1_5832, %int2_5833 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4842, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_5834 = torch.constant.int 1 + %int2_5835 = torch.constant.int 2 + %4843 = torch.aten.transpose.int %4840, %int1_5834, %int2_5835 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4843, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_5836 = torch.constant.float 0.000000e+00 + %false_5837 = torch.constant.bool false + %none_5838 = torch.constant.none + %4844:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4841, %4842, %4843, %float0.000000e00_5836, %false_5837, %327, %none_5838) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %4844#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_5839 = torch.constant.int 1 + %int2_5840 = torch.constant.int 2 + %4845 = torch.aten.transpose.int %4844#0, %int1_5839, %int2_5840 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4845, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_5841 = torch.constant.int 4 + %int4096_5842 = torch.constant.int 4096 + %4846 = torch.prim.ListConstruct %int4_5841, %298, %int4096_5842 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4847 = torch.aten.view %4845, %4846 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4847, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_5843 = torch.constant.int -2 + %int-1_5844 = torch.constant.int -1 + %4848 = torch.aten.transpose.int %150, %int-2_5843, %int-1_5844 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5845 = torch.constant.int 5 
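
The torch.operator call above (%4844) computes the same thing as PyTorch's scaled_dot_product_attention: fp16 Q/K/V in [4, 32, seq, 128] layout, zero dropout, non-causal, with the [4, 1, seq, seq] mask %327 broadcast over heads and no explicit scale; the second result is the logsumexp, which the IR never uses. A sketch under those assumptions, with a placeholder sequence length, f32 instead of the IR's f16 for portability, and an all-zeros additive mask standing in for %327:

```python
import torch
import torch.nn.functional as F

bs, n_heads, seq, head_dim = 4, 32, 64, 128   # seq is a placeholder for the ? dim
q = torch.randn(bs, n_heads, seq, head_dim)
k = torch.randn(bs, n_heads, seq, head_dim)
v = torch.randn(bs, n_heads, seq, head_dim)
mask = torch.zeros(bs, 1, seq, seq)           # stand-in for %327, broadcast over heads

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
out = out.transpose(1, 2).reshape(bs, seq, n_heads * head_dim)   # -> [4, seq, 4096]
```
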
+ %4849 = torch.prims.convert_element_type %4848, %int5_5845 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_5846 = torch.constant.int 4096 + %4850 = torch.prim.ListConstruct %342, %int4096_5846 : (!torch.int, !torch.int) -> !torch.list + %4851 = torch.aten.view %4847, %4850 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4851, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4852 = torch.aten.mm %4851, %4849 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4852, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_5847 = torch.constant.int 4 + %int4096_5848 = torch.constant.int 4096 + %4853 = torch.prim.ListConstruct %int4_5847, %298, %int4096_5848 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4854 = torch.aten.view %4852, %4853 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4854, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_5849 = torch.constant.int 1 + %4855 = torch.aten.add.Tensor %4621, %4854, %int1_5849 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4855, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_5850 = torch.constant.int 6 + %4856 = torch.prims.convert_element_type %4855, %int6_5850 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4856, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> %int2_5851 = torch.constant.int 2 - %4868 = torch.aten.floor_divide.Scalar %4867, %int2_5851 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_5852 = torch.constant.int 6 - %4869 = torch.prims.convert_element_type %4868, %int6_5852 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_5853 = torch.constant.int 128 - %4870 = torch.aten.div.Scalar %4869, %int128_5853 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_5854 = torch.constant.float 2.000000e+00 - %4871 = torch.aten.mul.Scalar %4870, %float2.000000e00_5854 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_5855 = torch.constant.float 5.000000e+05 - %4872 = torch.aten.pow.Scalar %float5.000000e05_5855, %4871 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %4873 = torch.aten.reciprocal %4872 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_5856 = torch.constant.float 1.000000e+00 - %4874 = torch.aten.mul.Scalar %4873, %float1.000000e00_5856 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_5857 = torch.constant.int 1 - %4875 = torch.aten.unsqueeze %4866, %int1_5857 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_5858 = torch.constant.int 0 - %4876 = torch.aten.unsqueeze %4874, %int0_5858 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %4877 = torch.aten.mul.Tensor %4875, %4876 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_5859 = torch.constant.int 1 - %4878 = torch.aten.size.int %4845, %int1_5859 : !torch.vtensor<[4,?,4096],f16>, 
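
The attention output projection above follows the IR's recurring matmul idiom: transpose the stored [out, in] weight, apply a convert_element_type to type 5 (f16, a no-op cast retained from the trace), flatten the activation to 2-D for torch.aten.mm, view back to 3-D, and add the residual (%4855). A sketch of that idiom, using a stand-in for the attn_output weight %150:

```python
import torch

def linear_3d(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # View [4, seq, d] as 2-D, mm against the transposed weight, view back.
    bs, seq, d = x.shape
    return (x.reshape(bs * seq, d) @ w.t()).reshape(bs, seq, -1)

attn_out = torch.randn(4, 32, 4096)       # f16 in the IR
w_o = torch.randn(4096, 4096)             # stand-in for %150
residual = torch.randn(4, 32, 4096)
h = residual + linear_3d(attn_out, w_o)   # %4855: projection plus residual add
```
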
!torch.int -> !torch.int - %int0_5860 = torch.constant.int 0 - %4879 = torch.aten.add.int %int0_5860, %4878 : !torch.int, !torch.int -> !torch.int - %int0_5861 = torch.constant.int 0 - %int0_5862 = torch.constant.int 0 - %int1_5863 = torch.constant.int 1 - %4880 = torch.aten.slice.Tensor %4877, %int0_5861, %int0_5862, %4879, %int1_5863 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4880, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_5864 = torch.constant.int 1 - %int0_5865 = torch.constant.int 0 - %int9223372036854775807_5866 = torch.constant.int 9223372036854775807 - %int1_5867 = torch.constant.int 1 - %4881 = torch.aten.slice.Tensor %4880, %int1_5864, %int0_5865, %int9223372036854775807_5866, %int1_5867 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4881, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_5868 = torch.constant.int 1 - %int0_5869 = torch.constant.int 0 - %int9223372036854775807_5870 = torch.constant.int 9223372036854775807 - %int1_5871 = torch.constant.int 1 - %4882 = torch.aten.slice.Tensor %4881, %int1_5868, %int0_5869, %int9223372036854775807_5870, %int1_5871 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4882, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_5872 = torch.constant.int 0 - %4883 = torch.aten.unsqueeze %4882, %int0_5872 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4883, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_5873 = torch.constant.int 1 - %int0_5874 = torch.constant.int 0 - %int9223372036854775807_5875 = torch.constant.int 9223372036854775807 - %int1_5876 = torch.constant.int 1 - %4884 = torch.aten.slice.Tensor %4883, %int1_5873, %int0_5874, %int9223372036854775807_5875, %int1_5876 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4884, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_5877 = torch.constant.int 2 - %int0_5878 = torch.constant.int 0 - %int9223372036854775807_5879 = torch.constant.int 9223372036854775807 - %int1_5880 = torch.constant.int 1 - %4885 = torch.aten.slice.Tensor %4884, %int2_5877, %int0_5878, %int9223372036854775807_5879, %int1_5880 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4885, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_5881 = torch.constant.int 4 - %int1_5882 = torch.constant.int 1 - %int1_5883 = torch.constant.int 1 - %4886 = torch.prim.ListConstruct %int4_5881, %int1_5882, %int1_5883 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4887 = torch.aten.repeat %4885, %4886 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %4887, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_5884 = torch.constant.int 6 - %4888 = torch.prims.convert_element_type %4861, %int6_5884 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape 
%4888, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %4889 = torch_c.to_builtin_tensor %4888 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %4890 = torch_c.to_builtin_tensor %4887 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %4891 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%4889, %4890) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %4892 = torch_c.from_builtin_tensor %4891 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %4892, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> + %4857 = torch.aten.pow.Tensor_Scalar %4856, %int2_5851 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4857, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_5852 = torch.constant.int -1 + %4858 = torch.prim.ListConstruct %int-1_5852 : (!torch.int) -> !torch.list + %true_5853 = torch.constant.bool true + %none_5854 = torch.constant.none + %4859 = torch.aten.mean.dim %4857, %4858, %true_5853, %none_5854 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4859, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_5855 = torch.constant.float 9.9999997473787516E-6 + %int1_5856 = torch.constant.int 1 + %4860 = torch.aten.add.Scalar %4859, %float9.999990e-06_5855, %int1_5856 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4860, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4861 = torch.aten.rsqrt %4860 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4861, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4862 = torch.aten.mul.Tensor %4856, %4861 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4862, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_5857 = torch.constant.int 5 + %4863 = torch.prims.convert_element_type %4862, %int5_5857 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4863, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %4864 = torch.aten.mul.Tensor %151, %4863 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4864, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_5858 = torch.constant.int 5 + %4865 = torch.prims.convert_element_type %4864, %int5_5858 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4865, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_5859 = torch.constant.int -2 + %int-1_5860 = torch.constant.int -1 + %4866 = torch.aten.transpose.int %152, %int-2_5859, %int-1_5860 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5861 = torch.constant.int 5 + %4867 = torch.prims.convert_element_type %4866, %int5_5861 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_5862 = torch.constant.int 4096 + %4868 = 
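
The convert/pow/mean/add/rsqrt/mul sequence from %4856 to %4865 is RMSNorm computed in f32 with an f16 round-trip, using eps = 9.9999997473787516e-06 (the f32 rounding of 1e-5) and the f32 weight %155. As a sketch:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             eps: float = 9.9999997473787516e-06) -> torch.Tensor:
    # Square, mean over the last dim, add eps, rsqrt, scale, then apply
    # the f32 weight and cast back to f16, exactly as laid out in the IR.
    x32 = x.to(torch.float32)
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (weight * normed.to(torch.float16)).to(torch.float16)

h = torch.randn(4, 32, 4096, dtype=torch.float16)
w = torch.ones(4096, dtype=torch.float32)   # stand-in for %155
out = rms_norm(h, w)                        # [4, 32, 4096], f16
```
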
torch.prim.ListConstruct %342, %int4096_5862 : (!torch.int, !torch.int) -> !torch.list + %4869 = torch.aten.view %4865, %4868 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4869, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4870 = torch.aten.mm %4869, %4867 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %4870, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_5863 = torch.constant.int 4 + %int14336_5864 = torch.constant.int 14336 + %4871 = torch.prim.ListConstruct %int4_5863, %298, %int14336_5864 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4872 = torch.aten.view %4870, %4871 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4872, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %4873 = torch.aten.silu %4872 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4873, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_5865 = torch.constant.int -2 + %int-1_5866 = torch.constant.int -1 + %4874 = torch.aten.transpose.int %153, %int-2_5865, %int-1_5866 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5867 = torch.constant.int 5 + %4875 = torch.prims.convert_element_type %4874, %int5_5867 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_5868 = torch.constant.int 4096 + %4876 = torch.prim.ListConstruct %342, %int4096_5868 : (!torch.int, !torch.int) -> !torch.list + %4877 = torch.aten.view %4865, %4876 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4877, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4878 = torch.aten.mm %4877, %4875 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %4878, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_5869 = torch.constant.int 4 + %int14336_5870 = torch.constant.int 14336 + %4879 = torch.prim.ListConstruct %int4_5869, %298, %int14336_5870 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4880 = torch.aten.view %4878, %4879 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4880, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %4881 = torch.aten.mul.Tensor %4873, %4880 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %4881, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_5871 = torch.constant.int -2 + %int-1_5872 = torch.constant.int -1 + %4882 = torch.aten.transpose.int %154, %int-2_5871, %int-1_5872 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_5873 = torch.constant.int 5 + %4883 = torch.prims.convert_element_type %4882, %int5_5873 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_5874 = torch.constant.int 14336 + %4884 = torch.prim.ListConstruct %342, %int14336_5874 : (!torch.int, !torch.int) -> !torch.list + %4885 = torch.aten.view 
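
The two matmuls into 14336 columns above, followed by torch.aten.silu and an elementwise product (%4881), implement the SwiGLU feed-forward; the [4096, 14336] down projection that follows returns to the model width. A sketch with hypothetical weight names, mirroring the IR's flatten-to-2D views around each matmul (f32 here; f16 in the IR):

```python
import torch
import torch.nn.functional as F

def swiglu_ffn(x, w_gate, w_up, w_down):
    # silu(gate) * up, then project down, flattening to 2-D around each mm.
    bs, seq, d = x.shape
    flat = x.reshape(bs * seq, d)                      # the IR's view to [?, 4096]
    h = F.silu(flat @ w_gate.t()) * (flat @ w_up.t())  # [?, 14336]
    return (h @ w_down.t()).reshape(bs, seq, d)

x = torch.randn(4, 32, 4096)
w_gate = torch.randn(14336, 4096)
w_up = torch.randn(14336, 4096)
w_down = torch.randn(4096, 14336)
y = swiglu_ffn(x, w_gate, w_up, w_down)    # [4, 32, 4096]
```
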
%4881, %4884 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %4885, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %4886 = torch.aten.mm %4885, %4883 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4886, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_5875 = torch.constant.int 4 + %int4096_5876 = torch.constant.int 4096 + %4887 = torch.prim.ListConstruct %int4_5875, %298, %int4096_5876 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4888 = torch.aten.view %4886, %4887 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4888, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_5877 = torch.constant.int 1 + %4889 = torch.aten.add.Tensor %4855, %4888, %int1_5877 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4889, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_5878 = torch.constant.int 6 + %4890 = torch.prims.convert_element_type %4889, %int6_5878 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4890, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_5879 = torch.constant.int 2 + %4891 = torch.aten.pow.Tensor_Scalar %4890, %int2_5879 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4891, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_5880 = torch.constant.int -1 + %4892 = torch.prim.ListConstruct %int-1_5880 : (!torch.int) -> !torch.list + %true_5881 = torch.constant.bool true + %none_5882 = torch.constant.none + %4893 = torch.aten.mean.dim %4891, %4892, %true_5881, %none_5882 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4893, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_5883 = torch.constant.float 9.9999997473787516E-6 + %int1_5884 = torch.constant.int 1 + %4894 = torch.aten.add.Scalar %4893, %float9.999990e-06_5883, %int1_5884 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4894, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4895 = torch.aten.rsqrt %4894 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %4895, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %4896 = torch.aten.mul.Tensor %4890, %4895 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4896, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> %int5_5885 = torch.constant.int 5 - %4893 = torch.prims.convert_element_type %4892, %int5_5885 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4893, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_5886 = torch.constant.int 131072 - %none_5887 = torch.constant.none - %none_5888 = torch.constant.none - %cpu_5889 = 
torch.constant.device "cpu" - %false_5890 = torch.constant.bool false - %4894 = torch.aten.arange %int131072_5886, %none_5887, %none_5888, %cpu_5889, %false_5890 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_5891 = torch.constant.int 0 - %int128_5892 = torch.constant.int 128 - %none_5893 = torch.constant.none - %none_5894 = torch.constant.none - %cpu_5895 = torch.constant.device "cpu" - %false_5896 = torch.constant.bool false - %4895 = torch.aten.arange.start %int0_5891, %int128_5892, %none_5893, %none_5894, %cpu_5895, %false_5896 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_5897 = torch.constant.int 2 - %4896 = torch.aten.floor_divide.Scalar %4895, %int2_5897 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_5898 = torch.constant.int 6 - %4897 = torch.prims.convert_element_type %4896, %int6_5898 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_5899 = torch.constant.int 128 - %4898 = torch.aten.div.Scalar %4897, %int128_5899 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_5900 = torch.constant.float 2.000000e+00 - %4899 = torch.aten.mul.Scalar %4898, %float2.000000e00_5900 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_5901 = torch.constant.float 5.000000e+05 - %4900 = torch.aten.pow.Scalar %float5.000000e05_5901, %4899 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %4901 = torch.aten.reciprocal %4900 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_5902 = torch.constant.float 1.000000e+00 - %4902 = torch.aten.mul.Scalar %4901, %float1.000000e00_5902 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_5903 = torch.constant.int 1 - %4903 = torch.aten.unsqueeze %4894, %int1_5903 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_5904 = torch.constant.int 0 - %4904 = torch.aten.unsqueeze %4902, %int0_5904 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %4905 = torch.aten.mul.Tensor %4903, %4904 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_5905 = torch.constant.int 1 - %4906 = torch.aten.size.int %4852, %int1_5905 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_5906 = torch.constant.int 0 - %4907 = torch.aten.add.int %int0_5906, %4906 : !torch.int, !torch.int -> !torch.int - %int0_5907 = torch.constant.int 0 - %int0_5908 = torch.constant.int 0 - %int1_5909 = torch.constant.int 1 - %4908 = torch.aten.slice.Tensor %4905, %int0_5907, %int0_5908, %4907, %int1_5909 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4908, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_5910 = torch.constant.int 1 - %int0_5911 = torch.constant.int 0 - %int9223372036854775807_5912 = torch.constant.int 9223372036854775807 - %int1_5913 = torch.constant.int 1 - %4909 = torch.aten.slice.Tensor %4908, %int1_5910, %int0_5911, %int9223372036854775807_5912, %int1_5913 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4909, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_5914 = 
torch.constant.int 1 - %int0_5915 = torch.constant.int 0 - %int9223372036854775807_5916 = torch.constant.int 9223372036854775807 - %int1_5917 = torch.constant.int 1 - %4910 = torch.aten.slice.Tensor %4909, %int1_5914, %int0_5915, %int9223372036854775807_5916, %int1_5917 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %4910, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_5918 = torch.constant.int 0 - %4911 = torch.aten.unsqueeze %4910, %int0_5918 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4911, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_5919 = torch.constant.int 1 - %int0_5920 = torch.constant.int 0 - %int9223372036854775807_5921 = torch.constant.int 9223372036854775807 - %int1_5922 = torch.constant.int 1 - %4912 = torch.aten.slice.Tensor %4911, %int1_5919, %int0_5920, %int9223372036854775807_5921, %int1_5922 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4912, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_5923 = torch.constant.int 2 - %int0_5924 = torch.constant.int 0 - %int9223372036854775807_5925 = torch.constant.int 9223372036854775807 - %int1_5926 = torch.constant.int 1 - %4913 = torch.aten.slice.Tensor %4912, %int2_5923, %int0_5924, %int9223372036854775807_5925, %int1_5926 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %4913, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_5927 = torch.constant.int 4 - %int1_5928 = torch.constant.int 1 - %int1_5929 = torch.constant.int 1 - %4914 = torch.prim.ListConstruct %int4_5927, %int1_5928, %int1_5929 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4915 = torch.aten.repeat %4913, %4914 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %4915, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_5930 = torch.constant.int 6 - %4916 = torch.prims.convert_element_type %4863, %int6_5930 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %4916, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %4917 = torch_c.to_builtin_tensor %4916 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %4918 = torch_c.to_builtin_tensor %4915 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %4919 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%4917, %4918) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %4920 = torch_c.from_builtin_tensor %4919 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %4920, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_5931 = torch.constant.int 5 - %4921 = torch.prims.convert_element_type %4920, %int5_5931 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4921, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_5932 = torch.constant.int 64 - %4922 = torch.aten.mul.Scalar %arg2, %int64_5932 : !torch.vtensor<[4,?],si64>, !torch.int 
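
The deleted blocks above inline the rotary-embedding angle table before handing it to the sharktank_rotary_embedding kernels: positions 0..131071 times an inverse-frequency vector derived from theta = 500000 over a head dimension of 128, with channels paired via i // 2 (the mul by 1.0 at %4902 is a retained no-op scale). A sketch of the same table:

```python
import torch

def rope_angle_table(max_pos: int = 131072, head_dim: int = 128,
                     theta: float = 5.0e5) -> torch.Tensor:
    # inv_freq[i] = 1 / theta ** (2 * (i // 2) / head_dim), as in the IR's
    # floor_divide / div / mul / pow / reciprocal chain.
    i = torch.arange(head_dim)
    inv_freq = 1.0 / theta ** ((i // 2).float() * 2.0 / head_dim)
    pos = torch.arange(max_pos).unsqueeze(1)   # [131072, 1]
    return pos * inv_freq.unsqueeze(0)         # [131072, 128]

angles = rope_angle_table()   # sliced per sequence position before use
```
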
-> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4922, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int46 = torch.constant.int 46 - %int1_5933 = torch.constant.int 1 - %4923 = torch.aten.add.Scalar %4922, %int46, %int1_5933 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4923, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_5934 = torch.constant.int 4 - %int32_5935 = torch.constant.int 32 - %int8_5936 = torch.constant.int 8 - %int128_5937 = torch.constant.int 128 - %4924 = torch.prim.ListConstruct %int4_5934, %398, %int32_5935, %int8_5936, %int128_5937 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4925 = torch.aten.view %4921, %4924 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4925, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_5938 = torch.constant.int 4 - %4926 = torch.aten.mul.int %int4_5938, %398 : !torch.int, !torch.int -> !torch.int - %int32_5939 = torch.constant.int 32 - %int8_5940 = torch.constant.int 8 - %int128_5941 = torch.constant.int 128 - %4927 = torch.prim.ListConstruct %4926, %int32_5939, %int8_5940, %int128_5941 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4928 = torch.aten.view %4925, %4927 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4928, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_5942 = torch.constant.int 4 - %4929 = torch.aten.mul.int %int4_5942, %398 : !torch.int, !torch.int -> !torch.int - %4930 = torch.prim.ListConstruct %4929 : (!torch.int) -> !torch.list - %4931 = torch.aten.view %4923, %4930 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %4931, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_5943 = torch.constant.int 32 - %int2_5944 = torch.constant.int 2 - %int32_5945 = torch.constant.int 32 - %int8_5946 = torch.constant.int 8 - %int128_5947 = torch.constant.int 128 - %4932 = torch.prim.ListConstruct %389, %int32_5943, %int2_5944, %int32_5945, %int8_5946, %int128_5947 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4933 = torch.aten.view %4765, %4932 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4933, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5948 = torch.constant.int 32 - %4934 = torch.aten.mul.int %389, %int32_5948 : !torch.int, !torch.int -> !torch.int - %int2_5949 = torch.constant.int 2 - %4935 = torch.aten.mul.int %4934, %int2_5949 : !torch.int, !torch.int -> !torch.int - %int32_5950 = torch.constant.int 32 - %int8_5951 = torch.constant.int 8 - %int128_5952 = torch.constant.int 128 - %4936 = torch.prim.ListConstruct %4935, %int32_5950, %int8_5951, %int128_5952 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4937 = torch.aten.view %4933, %4936 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4937, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %4938 = torch.prim.ListConstruct %4931 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_5953 = torch.constant.bool 
false - %4939 = torch.aten.index_put %4937, %4938, %4928, %false_5953 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4939, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_5954 = torch.constant.int 32 - %int2_5955 = torch.constant.int 2 - %int32_5956 = torch.constant.int 32 - %int8_5957 = torch.constant.int 8 - %int128_5958 = torch.constant.int 128 - %4940 = torch.prim.ListConstruct %389, %int32_5954, %int2_5955, %int32_5956, %int8_5957, %int128_5958 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4941 = torch.aten.view %4939, %4940 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4941, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5959 = torch.constant.int 2097152 - %4942 = torch.prim.ListConstruct %389, %int2097152_5959 : (!torch.int, !torch.int) -> !torch.list - %4943 = torch.aten.view %4941, %4942 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4943, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_5960 = torch.constant.int 32 - %int2_5961 = torch.constant.int 2 - %int32_5962 = torch.constant.int 32 - %int8_5963 = torch.constant.int 8 - %int128_5964 = torch.constant.int 128 - %4944 = torch.prim.ListConstruct %389, %int32_5960, %int2_5961, %int32_5962, %int8_5963, %int128_5964 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4945 = torch.aten.view %4943, %4944 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4945, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5965 = torch.constant.int 32 - %int8_5966 = torch.constant.int 8 - %int128_5967 = torch.constant.int 128 - %4946 = torch.prim.ListConstruct %4935, %int32_5965, %int8_5966, %int128_5967 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4947 = torch.aten.view %4945, %4946 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4947, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_5968 = torch.constant.int 4 - %int32_5969 = torch.constant.int 32 - %int8_5970 = torch.constant.int 8 - %int128_5971 = torch.constant.int 128 - %4948 = torch.prim.ListConstruct %int4_5968, %398, %int32_5969, %int8_5970, %int128_5971 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4949 = torch.aten.view %4865, %4948 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4949, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_5972 = torch.constant.int 4 - %4950 = torch.aten.mul.int %int4_5972, %398 : !torch.int, !torch.int -> !torch.int - %int32_5973 = torch.constant.int 32 - %int8_5974 = torch.constant.int 8 - %int128_5975 = torch.constant.int 128 - %4951 = torch.prim.ListConstruct %4950, %int32_5973, %int8_5974, %int128_5975 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4952 = torch.aten.view %4949, %4951 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> 
!torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4952, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %4897 = torch.prims.convert_element_type %4896, %int5_5885 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4897, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %4898 = torch.aten.mul.Tensor %155, %4897 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %4898, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_5886 = torch.constant.int 5 + %4899 = torch.prims.convert_element_type %4898, %int5_5886 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4899, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_5887 = torch.constant.int -2 + %int-1_5888 = torch.constant.int -1 + %4900 = torch.aten.transpose.int %156, %int-2_5887, %int-1_5888 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5889 = torch.constant.int 5 + %4901 = torch.prims.convert_element_type %4900, %int5_5889 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_5890 = torch.constant.int 4096 + %4902 = torch.prim.ListConstruct %342, %int4096_5890 : (!torch.int, !torch.int) -> !torch.list + %4903 = torch.aten.view %4899, %4902 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4903, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4904 = torch.aten.mm %4903, %4901 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4904, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_5891 = torch.constant.int 4 + %int4096_5892 = torch.constant.int 4096 + %4905 = torch.prim.ListConstruct %int4_5891, %298, %int4096_5892 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4906 = torch.aten.view %4904, %4905 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %4906, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_5893 = torch.constant.int -2 + %int-1_5894 = torch.constant.int -1 + %4907 = torch.aten.transpose.int %157, %int-2_5893, %int-1_5894 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5895 = torch.constant.int 5 + %4908 = torch.prims.convert_element_type %4907, %int5_5895 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_5896 = torch.constant.int 4096 + %4909 = torch.prim.ListConstruct %342, %int4096_5896 : (!torch.int, !torch.int) -> !torch.list + %4910 = torch.aten.view %4899, %4909 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4910, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4911 = torch.aten.mm %4910, %4908 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %4911, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_5897 = torch.constant.int 4 + %int1024_5898 = torch.constant.int 1024 + %4912 = 
torch.prim.ListConstruct %int4_5897, %298, %int1024_5898 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4913 = torch.aten.view %4911, %4912 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %4913, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_5899 = torch.constant.int -2 + %int-1_5900 = torch.constant.int -1 + %4914 = torch.aten.transpose.int %158, %int-2_5899, %int-1_5900 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5901 = torch.constant.int 5 + %4915 = torch.prims.convert_element_type %4914, %int5_5901 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_5902 = torch.constant.int 4096 + %4916 = torch.prim.ListConstruct %342, %int4096_5902 : (!torch.int, !torch.int) -> !torch.list + %4917 = torch.aten.view %4899, %4916 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %4917, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %4918 = torch.aten.mm %4917, %4915 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %4918, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_5903 = torch.constant.int 4 + %int1024_5904 = torch.constant.int 1024 + %4919 = torch.prim.ListConstruct %int4_5903, %298, %int1024_5904 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4920 = torch.aten.view %4918, %4919 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %4920, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_5905 = torch.constant.int 4 + %int32_5906 = torch.constant.int 32 + %int128_5907 = torch.constant.int 128 + %4921 = torch.prim.ListConstruct %int4_5905, %298, %int32_5906, %int128_5907 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4922 = torch.aten.view %4906, %4921 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4922, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_5908 = torch.constant.int 4 + %int8_5909 = torch.constant.int 8 + %int128_5910 = torch.constant.int 128 + %4923 = torch.prim.ListConstruct %int4_5908, %298, %int8_5909, %int128_5910 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4924 = torch.aten.view %4913, %4923 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4924, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_5911 = torch.constant.int 4 + %int8_5912 = torch.constant.int 8 + %int128_5913 = torch.constant.int 128 + %4925 = torch.prim.ListConstruct %int4_5911, %298, %int8_5912, %int128_5913 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4926 = torch.aten.view %4920, %4925 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4926, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_5914 = torch.constant.int 131072 + %none_5915 = torch.constant.none + %none_5916 = torch.constant.none + %cpu_5917 = torch.constant.device "cpu" + %false_5918 = torch.constant.bool false + %4927 = torch.aten.arange %int131072_5914, 
%none_5915, %none_5916, %cpu_5917, %false_5918 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_5919 = torch.constant.int 0 + %int128_5920 = torch.constant.int 128 + %int2_5921 = torch.constant.int 2 + %int4_5922 = torch.constant.int 4 + %none_5923 = torch.constant.none + %cpu_5924 = torch.constant.device "cpu" + %false_5925 = torch.constant.bool false + %4928 = torch.aten.arange.start_step %int0_5919, %int128_5920, %int2_5921, %int4_5922, %none_5923, %cpu_5924, %false_5925 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_5926 = torch.constant.int 6 + %4929 = torch.prims.convert_element_type %4928, %int6_5926 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_5927 = torch.constant.int 128 + %4930 = torch.aten.div.Scalar %4929, %int128_5927 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_5928 = torch.constant.float 5.000000e+05 + %4931 = torch.aten.pow.Scalar %float5.000000e05_5928, %4930 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4932 = torch.aten.reciprocal %4931 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_5929 = torch.constant.float 1.000000e+00 + %4933 = torch.aten.mul.Scalar %4932, %float1.000000e00_5929 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %4934 = torch.aten.reciprocal %4933 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_5930 = torch.constant.float 6.2831853071795862 + %4935 = torch.aten.mul.Scalar %4934, %float6.283190e00_5930 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_5931 = torch.constant.float 8.192000e+03 + %4936 = torch.aten.gt.Scalar %4935, %float8.192000e03_5931 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_5932 = torch.constant.int 8 + %4937 = torch.aten.div.Scalar %4933, %int8_5932 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %4938 = torch.aten.where.self %4936, %4937, %4933 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4939 = torch.aten.reciprocal %4935 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_5933 = torch.constant.int 8192 + %4940 = torch.aten.mul.Scalar %4939, %int8192_5933 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_5934 = torch.constant.int 1 + %int1_5935 = torch.constant.int 1 + %4941 = torch.aten.sub.Scalar %4940, %int1_5934, %int1_5935 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_5936 = torch.constant.int 3 + %4942 = torch.aten.div.Scalar %4941, %int3_5936 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_5937 = torch.constant.int 1 + %int1_5938 = torch.constant.int 1 + %4943 = torch.aten.rsub.Scalar %4942, %int1_5937, %int1_5938 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %4944 = torch.aten.mul.Tensor %4943, %4938 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_5939 = torch.constant.int 8 + %4945 = torch.aten.div.Scalar %4944, %int8_5939 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %4946 = torch.aten.mul.Tensor %4942, %4938 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_5940 = torch.constant.int 1 + %4947 = 
torch.aten.add.Tensor %4945, %4946, %int1_5940 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_5941 = torch.constant.float 2.048000e+03 + %4948 = torch.aten.lt.Scalar %4935, %float2.048000e03_5941 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %4949 = torch.aten.bitwise_not %4948 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_5942 = torch.constant.float 8.192000e+03 + %4950 = torch.aten.gt.Scalar %4935, %float8.192000e03_5942 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %4951 = torch.aten.bitwise_not %4950 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %4952 = torch.aten.mul.Tensor %4949, %4951 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %4953 = torch.aten.where.self %4952, %4947, %4938 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4954 = torch.prim.ListConstruct %4953, %4953 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_5943 = torch.constant.int -1 + %4955 = torch.aten.cat %4954, %int-1_5943 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_5944 = torch.constant.int 6 + %4956 = torch.prims.convert_element_type %4955, %int6_5944 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_5945 = torch.constant.int 1 + %4957 = torch.aten.unsqueeze %4927, %int1_5945 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_5946 = torch.constant.int 6 + %4958 = torch.prims.convert_element_type %4957, %int6_5946 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_5947 = torch.constant.int 0 + %4959 = torch.aten.unsqueeze %4956, %int0_5947 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_5948 = torch.constant.int 6 + %4960 = torch.prims.convert_element_type %4959, %int6_5948 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %4961 = torch.aten.mul.Tensor %4958, %4960 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %4962 = torch.aten.cos %4961 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_5949 = torch.constant.int 5 + %4963 = torch.prims.convert_element_type %4962, %int5_5949 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %4964 = torch.aten.sin %4961 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_5950 = torch.constant.int 5 + %4965 = torch.prims.convert_element_type %4964, %int5_5950 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_5951 = torch.constant.int 0 + %int0_5952 = torch.constant.int 0 + %int1_5953 = torch.constant.int 1 + %4966 = torch.aten.slice.Tensor %4963, %int0_5951, %int0_5952, %298, %int1_5953 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4966, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_5954 = torch.constant.int 1 + %int0_5955 = torch.constant.int 0 + %int9223372036854775807_5956 = torch.constant.int 9223372036854775807 + %int1_5957 = torch.constant.int 1 + %4967 = torch.aten.slice.Tensor %4966, %int1_5954, %int0_5955, %int9223372036854775807_5956, %int1_5957 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, 
!torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4967, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_5958 = torch.constant.int 0 + %int0_5959 = torch.constant.int 0 + %int1_5960 = torch.constant.int 1 + %4968 = torch.aten.slice.Tensor %4965, %int0_5958, %int0_5959, %298, %int1_5960 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4968, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_5961 = torch.constant.int 1 + %int0_5962 = torch.constant.int 0 + %int9223372036854775807_5963 = torch.constant.int 9223372036854775807 + %int1_5964 = torch.constant.int 1 + %4969 = torch.aten.slice.Tensor %4968, %int1_5961, %int0_5962, %int9223372036854775807_5963, %int1_5964 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4969, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_5965 = torch.constant.int 0 + %4970 = torch.aten.unsqueeze %4967, %int0_5965 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4970, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_5966 = torch.constant.int 1 + %int0_5967 = torch.constant.int 0 + %int9223372036854775807_5968 = torch.constant.int 9223372036854775807 + %int1_5969 = torch.constant.int 1 + %4971 = torch.aten.slice.Tensor %4970, %int1_5966, %int0_5967, %int9223372036854775807_5968, %int1_5969 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4971, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_5970 = torch.constant.int 2 + %4972 = torch.aten.unsqueeze %4971, %int2_5970 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4972, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_5971 = torch.constant.int 3 + %int0_5972 = torch.constant.int 0 + %int9223372036854775807_5973 = torch.constant.int 9223372036854775807 + %int1_5974 = torch.constant.int 1 + %4973 = torch.aten.slice.Tensor %4972, %int3_5971, %int0_5972, %int9223372036854775807_5973, %int1_5974 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4973, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_5975 = torch.constant.int 4 %int1_5976 = torch.constant.int 1 %int1_5977 = torch.constant.int 1 - %4953 = torch.aten.add.Scalar %4923, %int1_5976, %int1_5977 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4953, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_5978 = torch.constant.int 4 - %4954 = torch.aten.mul.int %int4_5978, %398 : !torch.int, !torch.int -> !torch.int - %4955 = torch.prim.ListConstruct %4954 : (!torch.int) -> !torch.list - %4956 = torch.aten.view %4953, %4955 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %4956, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %4957 = torch.prim.ListConstruct %4956 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_5979 = torch.constant.bool false - 
%4958 = torch.aten.index_put %4947, %4957, %4952, %false_5979 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %4958, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_5980 = torch.constant.int 32 - %int2_5981 = torch.constant.int 2 - %int32_5982 = torch.constant.int 32 - %int8_5983 = torch.constant.int 8 - %int128_5984 = torch.constant.int 128 - %4959 = torch.prim.ListConstruct %389, %int32_5980, %int2_5981, %int32_5982, %int8_5983, %int128_5984 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4960 = torch.aten.view %4958, %4959 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4960, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5985 = torch.constant.int 2097152 - %4961 = torch.prim.ListConstruct %389, %int2097152_5985 : (!torch.int, !torch.int) -> !torch.list - %4962 = torch.aten.view %4960, %4961 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4962, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_5986 = torch.constant.int -2 - %4963 = torch.aten.unsqueeze %4921, %int-2_5986 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4963, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_5987 = torch.constant.int 4 - %int8_5988 = torch.constant.int 8 + %int1_5978 = torch.constant.int 1 + %4974 = torch.prim.ListConstruct %int4_5975, %int1_5976, %int1_5977, %int1_5978 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4975 = torch.aten.repeat %4973, %4974 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %4975, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_5979 = torch.constant.int 0 + %4976 = torch.aten.unsqueeze %4969, %int0_5979 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4976, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_5980 = torch.constant.int 1 + %int0_5981 = torch.constant.int 0 + %int9223372036854775807_5982 = torch.constant.int 9223372036854775807 + %int1_5983 = torch.constant.int 1 + %4977 = torch.aten.slice.Tensor %4976, %int1_5980, %int0_5981, %int9223372036854775807_5982, %int1_5983 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %4977, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_5984 = torch.constant.int 2 + %4978 = torch.aten.unsqueeze %4977, %int2_5984 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4978, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_5985 = torch.constant.int 3 + %int0_5986 = torch.constant.int 0 + %int9223372036854775807_5987 = torch.constant.int 9223372036854775807 + %int1_5988 = torch.constant.int 1 + %4979 = torch.aten.slice.Tensor %4978, %int3_5985, %int0_5986, %int9223372036854775807_5987, %int1_5988 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, 
!torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %4979, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> %int4_5989 = torch.constant.int 4 - %int128_5990 = torch.constant.int 128 - %4964 = torch.prim.ListConstruct %int4_5987, %4906, %int8_5988, %int4_5989, %int128_5990 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_5991 = torch.constant.bool false - %4965 = torch.aten.expand %4963, %4964, %false_5991 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4965, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_5992 = torch.constant.int 0 - %4966 = torch.aten.clone %4965, %int0_5992 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4966, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_5993 = torch.constant.int 4 - %int32_5994 = torch.constant.int 32 - %int128_5995 = torch.constant.int 128 - %4967 = torch.prim.ListConstruct %int4_5993, %4906, %int32_5994, %int128_5995 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4968 = torch.aten._unsafe_view %4966, %4967 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4968, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_5996 = torch.constant.int -2 - %4969 = torch.aten.unsqueeze %4865, %int-2_5996 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4969, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_5997 = torch.constant.int 1 - %4970 = torch.aten.size.int %4859, %int1_5997 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_5998 = torch.constant.int 4 - %int8_5999 = torch.constant.int 8 - %int4_6000 = torch.constant.int 4 - %int128_6001 = torch.constant.int 128 - %4971 = torch.prim.ListConstruct %int4_5998, %4970, %int8_5999, %int4_6000, %int128_6001 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_6002 = torch.constant.bool false - %4972 = torch.aten.expand %4969, %4971, %false_6002 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4972, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_6003 = torch.constant.int 0 - %4973 = torch.aten.clone %4972, %int0_6003 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4973, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_6004 = torch.constant.int 4 - %int32_6005 = torch.constant.int 32 - %int128_6006 = torch.constant.int 128 - %4974 = torch.prim.ListConstruct %int4_6004, %4970, %int32_6005, %int128_6006 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4975 = torch.aten._unsafe_view %4973, %4974 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4975, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_6007 = torch.constant.int 1 - %int2_6008 = torch.constant.int 2 - %4976 = torch.aten.transpose.int %4893, 
%int1_6007, %int2_6008 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4976, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6009 = torch.constant.int 1 + %int1_5990 = torch.constant.int 1 + %int1_5991 = torch.constant.int 1 + %int1_5992 = torch.constant.int 1 + %4980 = torch.prim.ListConstruct %int4_5989, %int1_5990, %int1_5991, %int1_5992 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4981 = torch.aten.repeat %4979, %4980 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %4981, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %4982 = torch.aten.mul.Tensor %4922, %4975 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4982, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_5993 = torch.constant.int 3 + %int0_5994 = torch.constant.int 0 + %int64_5995 = torch.constant.int 64 + %int1_5996 = torch.constant.int 1 + %4983 = torch.aten.slice.Tensor %4922, %int3_5993, %int0_5994, %int64_5995, %int1_5996 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %4983, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_5997 = torch.constant.int 3 + %int64_5998 = torch.constant.int 64 + %int9223372036854775807_5999 = torch.constant.int 9223372036854775807 + %int1_6000 = torch.constant.int 1 + %4984 = torch.aten.slice.Tensor %4922, %int3_5997, %int64_5998, %int9223372036854775807_5999, %int1_6000 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %4984, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %4985 = torch.aten.neg %4984 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %4985, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %4986 = torch.prim.ListConstruct %4985, %4983 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_6001 = torch.constant.int -1 + %4987 = torch.aten.cat %4986, %int-1_6001 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4987, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %4988 = torch.aten.mul.Tensor %4987, %4981 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4988, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_6002 = torch.constant.int 1 + %4989 = torch.aten.add.Tensor %4982, %4988, %int1_6002 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4989, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_6003 = torch.constant.int 131072 + %none_6004 = torch.constant.none + %none_6005 = torch.constant.none + %cpu_6006 = torch.constant.device "cpu" + %false_6007 = torch.constant.bool false + %4990 = torch.aten.arange %int131072_6003, %none_6004, %none_6005, %cpu_6006, 
%false_6007 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_6008 = torch.constant.int 0 + %int128_6009 = torch.constant.int 128 %int2_6010 = torch.constant.int 2 - %4977 = torch.aten.transpose.int %4968, %int1_6009, %int2_6010 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4977, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6011 = torch.constant.int 1 - %int2_6012 = torch.constant.int 2 - %4978 = torch.aten.transpose.int %4975, %int1_6011, %int2_6012 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4978, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_6013 = torch.constant.float 0.000000e+00 - %true_6014 = torch.constant.bool true - %none_6015 = torch.constant.none - %none_6016 = torch.constant.none - %4979:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4976, %4977, %4978, %float0.000000e00_6013, %true_6014, %none_6015, %none_6016) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %4979#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6017 = torch.constant.int 1 - %int2_6018 = torch.constant.int 2 - %4980 = torch.aten.transpose.int %4979#0, %int1_6017, %int2_6018 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4980, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_6019 = torch.constant.int 4 - %int4096_6020 = torch.constant.int 4096 - %4981 = torch.prim.ListConstruct %int4_6019, %4878, %int4096_6020 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4982 = torch.aten.view %4980, %4981 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4982, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6021 = torch.constant.int -2 - %int-1_6022 = torch.constant.int -1 - %4983 = torch.aten.transpose.int %212, %int-2_6021, %int-1_6022 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6023 = torch.constant.int 4 - %4984 = torch.aten.mul.int %int4_6023, %4878 : !torch.int, !torch.int -> !torch.int - %int4096_6024 = torch.constant.int 4096 - %4985 = torch.prim.ListConstruct %4984, %int4096_6024 : (!torch.int, !torch.int) -> !torch.list - %4986 = torch.aten.view %4982, %4985 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4986, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %4987 = torch.aten.mm %4986, %4983 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %4987, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_6025 = torch.constant.int 4 - %int4096_6026 = torch.constant.int 4096 - %4988 = torch.prim.ListConstruct %int4_6025, %4878, %int4096_6026 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4989 = torch.aten.view %4987, %4988 : 
!torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4989, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int4_6011 = torch.constant.int 4 + %none_6012 = torch.constant.none + %cpu_6013 = torch.constant.device "cpu" + %false_6014 = torch.constant.bool false + %4991 = torch.aten.arange.start_step %int0_6008, %int128_6009, %int2_6010, %int4_6011, %none_6012, %cpu_6013, %false_6014 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_6015 = torch.constant.int 6 + %4992 = torch.prims.convert_element_type %4991, %int6_6015 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_6016 = torch.constant.int 128 + %4993 = torch.aten.div.Scalar %4992, %int128_6016 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_6017 = torch.constant.float 5.000000e+05 + %4994 = torch.aten.pow.Scalar %float5.000000e05_6017, %4993 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %4995 = torch.aten.reciprocal %4994 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_6018 = torch.constant.float 1.000000e+00 + %4996 = torch.aten.mul.Scalar %4995, %float1.000000e00_6018 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %4997 = torch.aten.reciprocal %4996 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_6019 = torch.constant.float 6.2831853071795862 + %4998 = torch.aten.mul.Scalar %4997, %float6.283190e00_6019 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_6020 = torch.constant.float 8.192000e+03 + %4999 = torch.aten.gt.Scalar %4998, %float8.192000e03_6020 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_6021 = torch.constant.int 8 + %5000 = torch.aten.div.Scalar %4996, %int8_6021 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5001 = torch.aten.where.self %4999, %5000, %4996 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5002 = torch.aten.reciprocal %4998 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_6022 = torch.constant.int 8192 + %5003 = torch.aten.mul.Scalar %5002, %int8192_6022 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_6023 = torch.constant.int 1 + %int1_6024 = torch.constant.int 1 + %5004 = torch.aten.sub.Scalar %5003, %int1_6023, %int1_6024 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_6025 = torch.constant.int 3 + %5005 = torch.aten.div.Scalar %5004, %int3_6025 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_6026 = torch.constant.int 1 %int1_6027 = torch.constant.int 1 - %4990 = torch.aten.add.Tensor %4828, %4989, %int1_6027 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4990, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_6028 = torch.constant.int 6 - %4991 = torch.prims.convert_element_type %4990, %int6_6028 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4991, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_6029 = torch.constant.int 2 - %4992 = torch.aten.pow.Tensor_Scalar %4991, %int2_6029 : 
!torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4992, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_6030 = torch.constant.int -1 - %4993 = torch.prim.ListConstruct %int-1_6030 : (!torch.int) -> !torch.list - %true_6031 = torch.constant.bool true - %none_6032 = torch.constant.none - %4994 = torch.aten.mean.dim %4992, %4993, %true_6031, %none_6032 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4994, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_6033 = torch.constant.float 9.9999997473787516E-6 + %5006 = torch.aten.rsub.Scalar %5005, %int1_6026, %int1_6027 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %5007 = torch.aten.mul.Tensor %5006, %5001 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_6028 = torch.constant.int 8 + %5008 = torch.aten.div.Scalar %5007, %int8_6028 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5009 = torch.aten.mul.Tensor %5005, %5001 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_6029 = torch.constant.int 1 + %5010 = torch.aten.add.Tensor %5008, %5009, %int1_6029 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_6030 = torch.constant.float 2.048000e+03 + %5011 = torch.aten.lt.Scalar %4998, %float2.048000e03_6030 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5012 = torch.aten.bitwise_not %5011 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_6031 = torch.constant.float 8.192000e+03 + %5013 = torch.aten.gt.Scalar %4998, %float8.192000e03_6031 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5014 = torch.aten.bitwise_not %5013 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5015 = torch.aten.mul.Tensor %5012, %5014 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5016 = torch.aten.where.self %5015, %5010, %5001 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5017 = torch.prim.ListConstruct %5016, %5016 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_6032 = torch.constant.int -1 + %5018 = torch.aten.cat %5017, %int-1_6032 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_6033 = torch.constant.int 6 + %5019 = torch.prims.convert_element_type %5018, %int6_6033 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> %int1_6034 = torch.constant.int 1 - %4995 = torch.aten.add.Scalar %4994, %float9.999990e-06_6033, %int1_6034 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4995, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4996 = torch.aten.rsqrt %4995 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %4996, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %4997 = torch.aten.mul.Tensor %4991, %4996 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4997, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6035 = torch.constant.int 5 - %4998 = 
torch.prims.convert_element_type %4997, %int5_6035 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %4998, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %4999 = torch.aten.mul.Tensor %213, %4998 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %4999, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6036 = torch.constant.int 5 - %5000 = torch.prims.convert_element_type %4999, %int5_6036 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5000, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6037 = torch.constant.int -2 - %int-1_6038 = torch.constant.int -1 - %5001 = torch.aten.transpose.int %214, %int-2_6037, %int-1_6038 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_6039 = torch.constant.int 4 - %5002 = torch.aten.mul.int %int4_6039, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6040 = torch.constant.int 4096 - %5003 = torch.prim.ListConstruct %5002, %int4096_6040 : (!torch.int, !torch.int) -> !torch.list - %5004 = torch.aten.view %5000, %5003 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5004, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5005 = torch.aten.mm %5004, %5001 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5005, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_6041 = torch.constant.int 4 - %int14336_6042 = torch.constant.int 14336 - %5006 = torch.prim.ListConstruct %int4_6041, %306, %int14336_6042 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5007 = torch.aten.view %5005, %5006 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5007, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %5008 = torch.aten.silu %5007 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5008, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_6043 = torch.constant.int -2 - %int-1_6044 = torch.constant.int -1 - %5009 = torch.aten.transpose.int %215, %int-2_6043, %int-1_6044 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_6045 = torch.constant.int 4 - %5010 = torch.aten.mul.int %int4_6045, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6046 = torch.constant.int 4096 - %5011 = torch.prim.ListConstruct %5010, %int4096_6046 : (!torch.int, !torch.int) -> !torch.list - %5012 = torch.aten.view %5000, %5011 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5012, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5013 = torch.aten.mm %5012, %5009 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5013, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_6047 = torch.constant.int 4 - %int14336_6048 = torch.constant.int 14336 - %5014 = torch.prim.ListConstruct %int4_6047, %306, 
%int14336_6048 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5015 = torch.aten.view %5013, %5014 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5015, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %5016 = torch.aten.mul.Tensor %5008, %5015 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5016, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_6049 = torch.constant.int -2 - %int-1_6050 = torch.constant.int -1 - %5017 = torch.aten.transpose.int %216, %int-2_6049, %int-1_6050 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_6051 = torch.constant.int 1 - %5018 = torch.aten.size.int %5007, %int1_6051 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_6052 = torch.constant.int 4 - %5019 = torch.aten.mul.int %int4_6052, %5018 : !torch.int, !torch.int -> !torch.int - %int14336_6053 = torch.constant.int 14336 - %5020 = torch.prim.ListConstruct %5019, %int14336_6053 : (!torch.int, !torch.int) -> !torch.list - %5021 = torch.aten.view %5016, %5020 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5021, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %5022 = torch.aten.mm %5021, %5017 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5022, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_6054 = torch.constant.int 4 - %int4096_6055 = torch.constant.int 4096 - %5023 = torch.prim.ListConstruct %int4_6054, %5018, %int4096_6055 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5024 = torch.aten.view %5022, %5023 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5024, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_6056 = torch.constant.int 1 - %5025 = torch.aten.add.Tensor %4990, %5024, %int1_6056 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5025, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_6057 = torch.constant.int 6 - %5026 = torch.prims.convert_element_type %5025, %int6_6057 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5026, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_6058 = torch.constant.int 2 - %5027 = torch.aten.pow.Tensor_Scalar %5026, %int2_6058 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5027, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_6059 = torch.constant.int -1 - %5028 = torch.prim.ListConstruct %int-1_6059 : (!torch.int) -> !torch.list - %true_6060 = torch.constant.bool true - %none_6061 = torch.constant.none - %5029 = torch.aten.mean.dim %5027, %5028, %true_6060, %none_6061 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5029, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_6062 = 
torch.constant.float 9.9999997473787516E-6 + %5020 = torch.aten.unsqueeze %4990, %int1_6034 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_6035 = torch.constant.int 6 + %5021 = torch.prims.convert_element_type %5020, %int6_6035 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_6036 = torch.constant.int 0 + %5022 = torch.aten.unsqueeze %5019, %int0_6036 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_6037 = torch.constant.int 6 + %5023 = torch.prims.convert_element_type %5022, %int6_6037 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %5024 = torch.aten.mul.Tensor %5021, %5023 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %5025 = torch.aten.cos %5024 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_6038 = torch.constant.int 5 + %5026 = torch.prims.convert_element_type %5025, %int5_6038 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %5027 = torch.aten.sin %5024 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_6039 = torch.constant.int 5 + %5028 = torch.prims.convert_element_type %5027, %int5_6039 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_6040 = torch.constant.int 0 + %int0_6041 = torch.constant.int 0 + %int1_6042 = torch.constant.int 1 + %5029 = torch.aten.slice.Tensor %5026, %int0_6040, %int0_6041, %298, %int1_6042 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5029, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_6043 = torch.constant.int 1 + %int0_6044 = torch.constant.int 0 + %int9223372036854775807_6045 = torch.constant.int 9223372036854775807 + %int1_6046 = torch.constant.int 1 + %5030 = torch.aten.slice.Tensor %5029, %int1_6043, %int0_6044, %int9223372036854775807_6045, %int1_6046 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5030, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_6047 = torch.constant.int 0 + %int0_6048 = torch.constant.int 0 + %int1_6049 = torch.constant.int 1 + %5031 = torch.aten.slice.Tensor %5028, %int0_6047, %int0_6048, %298, %int1_6049 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5031, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_6050 = torch.constant.int 1 + %int0_6051 = torch.constant.int 0 + %int9223372036854775807_6052 = torch.constant.int 9223372036854775807 + %int1_6053 = torch.constant.int 1 + %5032 = torch.aten.slice.Tensor %5031, %int1_6050, %int0_6051, %int9223372036854775807_6052, %int1_6053 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5032, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_6054 = torch.constant.int 0 + %5033 = torch.aten.unsqueeze %5030, %int0_6054 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5033, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_6055 = torch.constant.int 1 + %int0_6056 = 
torch.constant.int 0 + %int9223372036854775807_6057 = torch.constant.int 9223372036854775807 + %int1_6058 = torch.constant.int 1 + %5034 = torch.aten.slice.Tensor %5033, %int1_6055, %int0_6056, %int9223372036854775807_6057, %int1_6058 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5034, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_6059 = torch.constant.int 2 + %5035 = torch.aten.unsqueeze %5034, %int2_6059 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5035, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_6060 = torch.constant.int 3 + %int0_6061 = torch.constant.int 0 + %int9223372036854775807_6062 = torch.constant.int 9223372036854775807 %int1_6063 = torch.constant.int 1 - %5030 = torch.aten.add.Scalar %5029, %float9.999990e-06_6062, %int1_6063 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5030, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5031 = torch.aten.rsqrt %5030 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5031, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5032 = torch.aten.mul.Tensor %5026, %5031 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5032, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6064 = torch.constant.int 5 - %5033 = torch.prims.convert_element_type %5032, %int5_6064 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5033, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %5034 = torch.aten.mul.Tensor %217, %5033 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5034, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6065 = torch.constant.int 5 - %5035 = torch.prims.convert_element_type %5034, %int5_6065 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5035, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6066 = torch.constant.int -2 - %int-1_6067 = torch.constant.int -1 - %5036 = torch.aten.transpose.int %218, %int-2_6066, %int-1_6067 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6068 = torch.constant.int 4 - %5037 = torch.aten.mul.int %int4_6068, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6069 = torch.constant.int 4096 - %5038 = torch.prim.ListConstruct %5037, %int4096_6069 : (!torch.int, !torch.int) -> !torch.list - %5039 = torch.aten.view %5035, %5038 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5039, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5040 = torch.aten.mm %5039, %5036 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5040, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_6070 = torch.constant.int 4 - %int4096_6071 = torch.constant.int 4096 - 
%5041 = torch.prim.ListConstruct %int4_6070, %306, %int4096_6071 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5042 = torch.aten.view %5040, %5041 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5042, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6072 = torch.constant.int -2 - %int-1_6073 = torch.constant.int -1 - %5043 = torch.aten.transpose.int %219, %int-2_6072, %int-1_6073 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6074 = torch.constant.int 4 - %5044 = torch.aten.mul.int %int4_6074, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6075 = torch.constant.int 4096 - %5045 = torch.prim.ListConstruct %5044, %int4096_6075 : (!torch.int, !torch.int) -> !torch.list - %5046 = torch.aten.view %5035, %5045 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5046, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5047 = torch.aten.mm %5046, %5043 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %5047, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_6076 = torch.constant.int 4 - %int1024_6077 = torch.constant.int 1024 - %5048 = torch.prim.ListConstruct %int4_6076, %306, %int1024_6077 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5049 = torch.aten.view %5047, %5048 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %5049, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_6078 = torch.constant.int -2 - %int-1_6079 = torch.constant.int -1 - %5050 = torch.aten.transpose.int %220, %int-2_6078, %int-1_6079 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6080 = torch.constant.int 4 - %5051 = torch.aten.mul.int %int4_6080, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6081 = torch.constant.int 4096 - %5052 = torch.prim.ListConstruct %5051, %int4096_6081 : (!torch.int, !torch.int) -> !torch.list - %5053 = torch.aten.view %5035, %5052 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5053, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5054 = torch.aten.mm %5053, %5050 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %5054, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_6082 = torch.constant.int 4 - %int1024_6083 = torch.constant.int 1024 - %5055 = torch.prim.ListConstruct %int4_6082, %306, %int1024_6083 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5056 = torch.aten.view %5054, %5055 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %5056, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_6084 = torch.constant.int 4 - %int32_6085 = torch.constant.int 32 - %int128_6086 = torch.constant.int 128 - %5057 = torch.prim.ListConstruct %int4_6084, %306, %int32_6085, %int128_6086 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5058 = torch.aten.view %5042, %5057 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - 
torch.bind_symbolic_shape %5058, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_6087 = torch.constant.int 4 - %int8_6088 = torch.constant.int 8 - %int128_6089 = torch.constant.int 128 - %5059 = torch.prim.ListConstruct %int4_6087, %306, %int8_6088, %int128_6089 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5060 = torch.aten.view %5049, %5059 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5060, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_6090 = torch.constant.int 4 - %int8_6091 = torch.constant.int 8 - %int128_6092 = torch.constant.int 128 - %5061 = torch.prim.ListConstruct %int4_6090, %306, %int8_6091, %int128_6092 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5062 = torch.aten.view %5056, %5061 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5062, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_6093 = torch.constant.int 131072 - %none_6094 = torch.constant.none - %none_6095 = torch.constant.none - %cpu_6096 = torch.constant.device "cpu" - %false_6097 = torch.constant.bool false - %5063 = torch.aten.arange %int131072_6093, %none_6094, %none_6095, %cpu_6096, %false_6097 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_6098 = torch.constant.int 0 - %int128_6099 = torch.constant.int 128 - %none_6100 = torch.constant.none - %none_6101 = torch.constant.none - %cpu_6102 = torch.constant.device "cpu" - %false_6103 = torch.constant.bool false - %5064 = torch.aten.arange.start %int0_6098, %int128_6099, %none_6100, %none_6101, %cpu_6102, %false_6103 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_6104 = torch.constant.int 2 - %5065 = torch.aten.floor_divide.Scalar %5064, %int2_6104 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_6105 = torch.constant.int 6 - %5066 = torch.prims.convert_element_type %5065, %int6_6105 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_6106 = torch.constant.int 128 - %5067 = torch.aten.div.Scalar %5066, %int128_6106 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_6107 = torch.constant.float 2.000000e+00 - %5068 = torch.aten.mul.Scalar %5067, %float2.000000e00_6107 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_6108 = torch.constant.float 5.000000e+05 - %5069 = torch.aten.pow.Scalar %float5.000000e05_6108, %5068 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %5070 = torch.aten.reciprocal %5069 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_6109 = torch.constant.float 1.000000e+00 - %5071 = torch.aten.mul.Scalar %5070, %float1.000000e00_6109 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_6110 = torch.constant.int 1 - %5072 = torch.aten.unsqueeze %5063, %int1_6110 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_6111 = torch.constant.int 0 - %5073 = torch.aten.unsqueeze %5071, %int0_6111 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %5074 = torch.aten.mul.Tensor %5072, %5073 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> 
!torch.vtensor<[131072,128],f32> - %int1_6112 = torch.constant.int 1 - %5075 = torch.aten.size.int %5042, %int1_6112 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_6113 = torch.constant.int 0 - %5076 = torch.aten.add.int %int0_6113, %5075 : !torch.int, !torch.int -> !torch.int - %int0_6114 = torch.constant.int 0 - %int0_6115 = torch.constant.int 0 - %int1_6116 = torch.constant.int 1 - %5077 = torch.aten.slice.Tensor %5074, %int0_6114, %int0_6115, %5076, %int1_6116 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5077, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_6117 = torch.constant.int 1 - %int0_6118 = torch.constant.int 0 - %int9223372036854775807_6119 = torch.constant.int 9223372036854775807 - %int1_6120 = torch.constant.int 1 - %5078 = torch.aten.slice.Tensor %5077, %int1_6117, %int0_6118, %int9223372036854775807_6119, %int1_6120 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5078, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_6121 = torch.constant.int 1 - %int0_6122 = torch.constant.int 0 - %int9223372036854775807_6123 = torch.constant.int 9223372036854775807 - %int1_6124 = torch.constant.int 1 - %5079 = torch.aten.slice.Tensor %5078, %int1_6121, %int0_6122, %int9223372036854775807_6123, %int1_6124 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5079, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_6125 = torch.constant.int 0 - %5080 = torch.aten.unsqueeze %5079, %int0_6125 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5080, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_6126 = torch.constant.int 1 - %int0_6127 = torch.constant.int 0 - %int9223372036854775807_6128 = torch.constant.int 9223372036854775807 - %int1_6129 = torch.constant.int 1 - %5081 = torch.aten.slice.Tensor %5080, %int1_6126, %int0_6127, %int9223372036854775807_6128, %int1_6129 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5081, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_6130 = torch.constant.int 2 - %int0_6131 = torch.constant.int 0 - %int9223372036854775807_6132 = torch.constant.int 9223372036854775807 - %int1_6133 = torch.constant.int 1 - %5082 = torch.aten.slice.Tensor %5081, %int2_6130, %int0_6131, %int9223372036854775807_6132, %int1_6133 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5082, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_6134 = torch.constant.int 4 + %5036 = torch.aten.slice.Tensor %5035, %int3_6060, %int0_6061, %int9223372036854775807_6062, %int1_6063 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5036, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_6064 = torch.constant.int 4 + %int1_6065 = torch.constant.int 1 + %int1_6066 = torch.constant.int 1 + %int1_6067 = torch.constant.int 1 + 
%5037 = torch.prim.ListConstruct %int4_6064, %int1_6065, %int1_6066, %int1_6067 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5038 = torch.aten.repeat %5036, %5037 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5038, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_6068 = torch.constant.int 0 + %5039 = torch.aten.unsqueeze %5032, %int0_6068 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5039, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_6069 = torch.constant.int 1 + %int0_6070 = torch.constant.int 0 + %int9223372036854775807_6071 = torch.constant.int 9223372036854775807 + %int1_6072 = torch.constant.int 1 + %5040 = torch.aten.slice.Tensor %5039, %int1_6069, %int0_6070, %int9223372036854775807_6071, %int1_6072 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5040, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_6073 = torch.constant.int 2 + %5041 = torch.aten.unsqueeze %5040, %int2_6073 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5041, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_6074 = torch.constant.int 3 + %int0_6075 = torch.constant.int 0 + %int9223372036854775807_6076 = torch.constant.int 9223372036854775807 + %int1_6077 = torch.constant.int 1 + %5042 = torch.aten.slice.Tensor %5041, %int3_6074, %int0_6075, %int9223372036854775807_6076, %int1_6077 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5042, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_6078 = torch.constant.int 4 + %int1_6079 = torch.constant.int 1 + %int1_6080 = torch.constant.int 1 + %int1_6081 = torch.constant.int 1 + %5043 = torch.prim.ListConstruct %int4_6078, %int1_6079, %int1_6080, %int1_6081 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5044 = torch.aten.repeat %5042, %5043 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5044, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %5045 = torch.aten.mul.Tensor %4924, %5038 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5045, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_6082 = torch.constant.int 3 + %int0_6083 = torch.constant.int 0 + %int64_6084 = torch.constant.int 64 + %int1_6085 = torch.constant.int 1 + %5046 = torch.aten.slice.Tensor %4924, %int3_6082, %int0_6083, %int64_6084, %int1_6085 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %5046, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_6086 = torch.constant.int 3 + %int64_6087 = torch.constant.int 64 + %int9223372036854775807_6088 = torch.constant.int 9223372036854775807 + %int1_6089 = torch.constant.int 1 + %5047 = torch.aten.slice.Tensor %4924, %int3_6086, %int64_6087, %int9223372036854775807_6088, %int1_6089 : 
!torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %5047, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %5048 = torch.aten.neg %5047 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %5048, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %5049 = torch.prim.ListConstruct %5048, %5046 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_6090 = torch.constant.int -1 + %5050 = torch.aten.cat %5049, %int-1_6090 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5050, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %5051 = torch.aten.mul.Tensor %5050, %5044 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5051, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_6091 = torch.constant.int 1 + %5052 = torch.aten.add.Tensor %5045, %5051, %int1_6091 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5052, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_6092 = torch.constant.int 32 + %5053 = torch.aten.mul.Scalar %arg2, %int32_6092 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5053, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int17 = torch.constant.int 17 + %int1_6093 = torch.constant.int 1 + %5054 = torch.aten.add.Scalar %5053, %int17, %int1_6093 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5054, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_6094 = torch.constant.int 2 + %5055 = torch.aten.mul.Scalar %5054, %int2_6094 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5055, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_6095 = torch.constant.int 0 + %int1_6096 = torch.constant.int 1 + %5056 = torch.aten.add.Scalar %5055, %int0_6095, %int1_6096 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5056, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %5057 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %5058 = torch.aten.view %5056, %5057 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %5058, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_6097 = torch.constant.int 4 + %int32_6098 = torch.constant.int 32 + %int8_6099 = torch.constant.int 8 + %int128_6100 = torch.constant.int 128 + %5059 = torch.prim.ListConstruct %int4_6097, %296, %int32_6098, %int8_6099, %int128_6100 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5060 = torch.aten.view %5052, %5059 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5060, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_6101 = torch.constant.int 32 + %int8_6102 = torch.constant.int 8 + %int128_6103 = torch.constant.int 128 + %5061 
= torch.prim.ListConstruct %504, %int32_6101, %int8_6102, %int128_6103 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5062 = torch.aten.view %5060, %5061 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %5062, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_6104 = torch.constant.int 1 + %int2_6105 = torch.constant.int 2 + %5063 = torch.aten.transpose.int %5062, %int1_6104, %int2_6105 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5063, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_6106 = torch.constant.int 5 + %5064 = torch.prims.convert_element_type %5063, %int5_6106 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5064, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_6107 = torch.constant.int 32 + %int2_6108 = torch.constant.int 2 + %int8_6109 = torch.constant.int 8 + %int32_6110 = torch.constant.int 32 + %int128_6111 = torch.constant.int 128 + %5065 = torch.prim.ListConstruct %297, %int32_6107, %int2_6108, %int8_6109, %int32_6110, %int128_6111 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5066 = torch.aten.view %4828, %5065 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5066, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_6112 = torch.constant.int 8 + %int32_6113 = torch.constant.int 32 + %int128_6114 = torch.constant.int 128 + %5067 = torch.prim.ListConstruct %497, %int8_6112, %int32_6113, %int128_6114 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5068 = torch.aten.view %5066, %5067 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5068, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %5069 = torch.prim.ListConstruct %5058 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_6115 = torch.constant.bool false + %5070 = torch.aten.index_put %5068, %5069, %5064, %false_6115 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5070, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_6116 = torch.constant.int 32 + %int2_6117 = torch.constant.int 2 + %int8_6118 = torch.constant.int 8 + %int32_6119 = torch.constant.int 32 + %int128_6120 = torch.constant.int 128 + %5071 = torch.prim.ListConstruct %297, %int32_6116, %int2_6117, %int8_6118, %int32_6119, %int128_6120 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5072 = torch.aten.view %5070, %5071 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5072, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_6121 = torch.constant.int 2097152 + %5073 = torch.prim.ListConstruct %297, %int2097152_6121 : (!torch.int, !torch.int) -> !torch.list + %5074 = torch.aten.view %5072, %5073 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + 
torch.bind_symbolic_shape %5074, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_6122 = torch.constant.int 32 + %int2_6123 = torch.constant.int 2 + %int8_6124 = torch.constant.int 8 + %int32_6125 = torch.constant.int 32 + %int128_6126 = torch.constant.int 128 + %5075 = torch.prim.ListConstruct %297, %int32_6122, %int2_6123, %int8_6124, %int32_6125, %int128_6126 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5076 = torch.aten.view %5074, %5075 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5076, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_6127 = torch.constant.int 8 + %int32_6128 = torch.constant.int 32 + %int128_6129 = torch.constant.int 128 + %5077 = torch.prim.ListConstruct %497, %int8_6127, %int32_6128, %int128_6129 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5078 = torch.aten.view %5076, %5077 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5078, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_6130 = torch.constant.int 32 + %5079 = torch.aten.mul.Scalar %arg2, %int32_6130 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5079, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int17_6131 = torch.constant.int 17 + %int1_6132 = torch.constant.int 1 + %5080 = torch.aten.add.Scalar %5079, %int17_6131, %int1_6132 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5080, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_6133 = torch.constant.int 2 + %5081 = torch.aten.mul.Scalar %5080, %int2_6133 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5081, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_6134 = torch.constant.int 1 %int1_6135 = torch.constant.int 1 - %int1_6136 = torch.constant.int 1 - %5083 = torch.prim.ListConstruct %int4_6134, %int1_6135, %int1_6136 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5084 = torch.aten.repeat %5082, %5083 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %5084, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_6137 = torch.constant.int 6 - %5085 = torch.prims.convert_element_type %5058, %int6_6137 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %5085, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %5086 = torch_c.to_builtin_tensor %5085 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %5087 = torch_c.to_builtin_tensor %5084 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %5088 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%5086, %5087) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %5089 = torch_c.from_builtin_tensor %5088 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %5089, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_6138 = torch.constant.int 5 - %5090 = torch.prims.convert_element_type %5089, %int5_6138 : 
!torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5090, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_6139 = torch.constant.int 131072 - %none_6140 = torch.constant.none - %none_6141 = torch.constant.none - %cpu_6142 = torch.constant.device "cpu" - %false_6143 = torch.constant.bool false - %5091 = torch.aten.arange %int131072_6139, %none_6140, %none_6141, %cpu_6142, %false_6143 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_6144 = torch.constant.int 0 - %int128_6145 = torch.constant.int 128 - %none_6146 = torch.constant.none - %none_6147 = torch.constant.none - %cpu_6148 = torch.constant.device "cpu" - %false_6149 = torch.constant.bool false - %5092 = torch.aten.arange.start %int0_6144, %int128_6145, %none_6146, %none_6147, %cpu_6148, %false_6149 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_6150 = torch.constant.int 2 - %5093 = torch.aten.floor_divide.Scalar %5092, %int2_6150 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_6151 = torch.constant.int 6 - %5094 = torch.prims.convert_element_type %5093, %int6_6151 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_6152 = torch.constant.int 128 - %5095 = torch.aten.div.Scalar %5094, %int128_6152 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_6153 = torch.constant.float 2.000000e+00 - %5096 = torch.aten.mul.Scalar %5095, %float2.000000e00_6153 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_6154 = torch.constant.float 5.000000e+05 - %5097 = torch.aten.pow.Scalar %float5.000000e05_6154, %5096 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %5098 = torch.aten.reciprocal %5097 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_6155 = torch.constant.float 1.000000e+00 - %5099 = torch.aten.mul.Scalar %5098, %float1.000000e00_6155 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_6156 = torch.constant.int 1 - %5100 = torch.aten.unsqueeze %5091, %int1_6156 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_6157 = torch.constant.int 0 - %5101 = torch.aten.unsqueeze %5099, %int0_6157 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %5102 = torch.aten.mul.Tensor %5100, %5101 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_6158 = torch.constant.int 1 - %5103 = torch.aten.size.int %5049, %int1_6158 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int + %5082 = torch.aten.add.Scalar %5081, %int1_6134, %int1_6135 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5082, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %5083 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %5084 = torch.aten.view %5082, %5083 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %5084, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_6136 = torch.constant.int 4 + %int32_6137 = torch.constant.int 32 + %int8_6138 = torch.constant.int 8 + %int128_6139 = torch.constant.int 128 + %5085 = torch.prim.ListConstruct %int4_6136, %296, 
%int32_6137, %int8_6138, %int128_6139 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5086 = torch.aten.view %4926, %5085 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5086, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_6140 = torch.constant.int 32 + %int8_6141 = torch.constant.int 8 + %int128_6142 = torch.constant.int 128 + %5087 = torch.prim.ListConstruct %504, %int32_6140, %int8_6141, %int128_6142 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5088 = torch.aten.view %5086, %5087 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %5088, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_6143 = torch.constant.int 1 + %int2_6144 = torch.constant.int 2 + %5089 = torch.aten.transpose.int %5088, %int1_6143, %int2_6144 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5089, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_6145 = torch.constant.int 5 + %5090 = torch.prims.convert_element_type %5089, %int5_6145 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5090, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %5091 = torch.prim.ListConstruct %5084 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_6146 = torch.constant.bool false + %5092 = torch.aten.index_put %5078, %5091, %5090, %false_6146 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5092, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_6147 = torch.constant.int 32 + %int2_6148 = torch.constant.int 2 + %int8_6149 = torch.constant.int 8 + %int32_6150 = torch.constant.int 32 + %int128_6151 = torch.constant.int 128 + %5093 = torch.prim.ListConstruct %297, %int32_6147, %int2_6148, %int8_6149, %int32_6150, %int128_6151 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5094 = torch.aten.view %5092, %5093 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5094, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_6152 = torch.constant.int 2097152 + %5095 = torch.prim.ListConstruct %297, %int2097152_6152 : (!torch.int, !torch.int) -> !torch.list + %5096 = torch.aten.view %5094, %5095 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5096, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_6153 = torch.constant.int -2 + %5097 = torch.aten.unsqueeze %5052, %int-2_6153 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5097, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_6154 = torch.constant.int 4 + %int8_6155 = torch.constant.int 8 + %int4_6156 = torch.constant.int 4 + %int128_6157 = torch.constant.int 128 + %5098 = torch.prim.ListConstruct %int4_6154, %298, %int8_6155, %int4_6156, %int128_6157 : (!torch.int, !torch.int, 
!torch.int, !torch.int, !torch.int) -> !torch.list + %false_6158 = torch.constant.bool false + %5099 = torch.aten.expand %5097, %5098, %false_6158 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5099, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int0_6159 = torch.constant.int 0 - %5104 = torch.aten.add.int %int0_6159, %5103 : !torch.int, !torch.int -> !torch.int - %int0_6160 = torch.constant.int 0 - %int0_6161 = torch.constant.int 0 - %int1_6162 = torch.constant.int 1 - %5105 = torch.aten.slice.Tensor %5102, %int0_6160, %int0_6161, %5104, %int1_6162 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5105, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_6163 = torch.constant.int 1 - %int0_6164 = torch.constant.int 0 - %int9223372036854775807_6165 = torch.constant.int 9223372036854775807 - %int1_6166 = torch.constant.int 1 - %5106 = torch.aten.slice.Tensor %5105, %int1_6163, %int0_6164, %int9223372036854775807_6165, %int1_6166 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5106, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_6167 = torch.constant.int 1 - %int0_6168 = torch.constant.int 0 - %int9223372036854775807_6169 = torch.constant.int 9223372036854775807 - %int1_6170 = torch.constant.int 1 - %5107 = torch.aten.slice.Tensor %5106, %int1_6167, %int0_6168, %int9223372036854775807_6169, %int1_6170 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5107, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_6171 = torch.constant.int 0 - %5108 = torch.aten.unsqueeze %5107, %int0_6171 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5108, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_6172 = torch.constant.int 1 - %int0_6173 = torch.constant.int 0 - %int9223372036854775807_6174 = torch.constant.int 9223372036854775807 + %5100 = torch.aten.clone %5099, %int0_6159 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5100, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_6160 = torch.constant.int 4 + %int32_6161 = torch.constant.int 32 + %int128_6162 = torch.constant.int 128 + %5101 = torch.prim.ListConstruct %int4_6160, %298, %int32_6161, %int128_6162 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5102 = torch.aten._unsafe_view %5100, %5101 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5102, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_6163 = torch.constant.int -2 + %5103 = torch.aten.unsqueeze %4926, %int-2_6163 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5103, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_6164 = torch.constant.int 4 + %int8_6165 = torch.constant.int 8 + %int4_6166 = torch.constant.int 4 + %int128_6167 = torch.constant.int 128 + %5104 
= torch.prim.ListConstruct %int4_6164, %298, %int8_6165, %int4_6166, %int128_6167 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_6168 = torch.constant.bool false + %5105 = torch.aten.expand %5103, %5104, %false_6168 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5105, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_6169 = torch.constant.int 0 + %5106 = torch.aten.clone %5105, %int0_6169 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5106, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_6170 = torch.constant.int 4 + %int32_6171 = torch.constant.int 32 + %int128_6172 = torch.constant.int 128 + %5107 = torch.prim.ListConstruct %int4_6170, %298, %int32_6171, %int128_6172 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5108 = torch.aten._unsafe_view %5106, %5107 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5108, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_6173 = torch.constant.int 1 + %int2_6174 = torch.constant.int 2 + %5109 = torch.aten.transpose.int %4989, %int1_6173, %int2_6174 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5109, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_6175 = torch.constant.int 1 - %5109 = torch.aten.slice.Tensor %5108, %int1_6172, %int0_6173, %int9223372036854775807_6174, %int1_6175 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5109, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> %int2_6176 = torch.constant.int 2 - %int0_6177 = torch.constant.int 0 - %int9223372036854775807_6178 = torch.constant.int 9223372036854775807 - %int1_6179 = torch.constant.int 1 - %5110 = torch.aten.slice.Tensor %5109, %int2_6176, %int0_6177, %int9223372036854775807_6178, %int1_6179 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5110, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_6180 = torch.constant.int 4 - %int1_6181 = torch.constant.int 1 + %5110 = torch.aten.transpose.int %5102, %int1_6175, %int2_6176 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5110, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_6177 = torch.constant.int 1 + %int2_6178 = torch.constant.int 2 + %5111 = torch.aten.transpose.int %5108, %int1_6177, %int2_6178 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5111, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_6179 = torch.constant.float 0.000000e+00 + %false_6180 = torch.constant.bool false + %none_6181 = torch.constant.none + %5112:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5109, %5110, %5111, %float0.000000e00_6179, %false_6180, %327, %none_6181) : 
(!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %5112#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_6182 = torch.constant.int 1 - %5111 = torch.prim.ListConstruct %int4_6180, %int1_6181, %int1_6182 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5112 = torch.aten.repeat %5110, %5111 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %5112, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_6183 = torch.constant.int 6 - %5113 = torch.prims.convert_element_type %5060, %int6_6183 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %5113, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %5114 = torch_c.to_builtin_tensor %5113 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %5115 = torch_c.to_builtin_tensor %5112 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %5116 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%5114, %5115) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %5117 = torch_c.from_builtin_tensor %5116 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %5117, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_6184 = torch.constant.int 5 - %5118 = torch.prims.convert_element_type %5117, %int5_6184 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5118, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_6185 = torch.constant.int 64 - %5119 = torch.aten.mul.Scalar %arg2, %int64_6185 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5119, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int48 = torch.constant.int 48 - %int1_6186 = torch.constant.int 1 - %5120 = torch.aten.add.Scalar %5119, %int48, %int1_6186 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5120, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_6187 = torch.constant.int 4 - %int32_6188 = torch.constant.int 32 - %int8_6189 = torch.constant.int 8 - %int128_6190 = torch.constant.int 128 - %5121 = torch.prim.ListConstruct %int4_6187, %398, %int32_6188, %int8_6189, %int128_6190 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5122 = torch.aten.view %5118, %5121 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5122, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_6191 = torch.constant.int 4 - %5123 = torch.aten.mul.int %int4_6191, %398 : !torch.int, !torch.int -> !torch.int - %int32_6192 = torch.constant.int 32 - %int8_6193 = torch.constant.int 8 - %int128_6194 = torch.constant.int 128 - %5124 = torch.prim.ListConstruct %5123, %int32_6192, %int8_6193, %int128_6194 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5125 = torch.aten.view %5122, %5124 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> 
!torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5125, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_6195 = torch.constant.int 4 - %5126 = torch.aten.mul.int %int4_6195, %398 : !torch.int, !torch.int -> !torch.int - %5127 = torch.prim.ListConstruct %5126 : (!torch.int) -> !torch.list - %5128 = torch.aten.view %5120, %5127 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %5128, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_6196 = torch.constant.int 32 - %int2_6197 = torch.constant.int 2 - %int32_6198 = torch.constant.int 32 - %int8_6199 = torch.constant.int 8 - %int128_6200 = torch.constant.int 128 - %5129 = torch.prim.ListConstruct %389, %int32_6196, %int2_6197, %int32_6198, %int8_6199, %int128_6200 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5130 = torch.aten.view %4962, %5129 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5130, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6201 = torch.constant.int 32 - %5131 = torch.aten.mul.int %389, %int32_6201 : !torch.int, !torch.int -> !torch.int - %int2_6202 = torch.constant.int 2 - %5132 = torch.aten.mul.int %5131, %int2_6202 : !torch.int, !torch.int -> !torch.int - %int32_6203 = torch.constant.int 32 - %int8_6204 = torch.constant.int 8 - %int128_6205 = torch.constant.int 128 - %5133 = torch.prim.ListConstruct %5132, %int32_6203, %int8_6204, %int128_6205 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5134 = torch.aten.view %5130, %5133 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5134, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %5135 = torch.prim.ListConstruct %5128 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_6206 = torch.constant.bool false - %5136 = torch.aten.index_put %5134, %5135, %5125, %false_6206 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5136, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_6207 = torch.constant.int 32 - %int2_6208 = torch.constant.int 2 - %int32_6209 = torch.constant.int 32 - %int8_6210 = torch.constant.int 8 - %int128_6211 = torch.constant.int 128 - %5137 = torch.prim.ListConstruct %389, %int32_6207, %int2_6208, %int32_6209, %int8_6210, %int128_6211 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5138 = torch.aten.view %5136, %5137 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5138, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6212 = torch.constant.int 2097152 - %5139 = torch.prim.ListConstruct %389, %int2097152_6212 : (!torch.int, !torch.int) -> !torch.list - %5140 = torch.aten.view %5138, %5139 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5140, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_6213 = torch.constant.int 32 - %int2_6214 = torch.constant.int 2 - %int32_6215 = torch.constant.int 32 - %int8_6216 = 
torch.constant.int 8 - %int128_6217 = torch.constant.int 128 - %5141 = torch.prim.ListConstruct %389, %int32_6213, %int2_6214, %int32_6215, %int8_6216, %int128_6217 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5142 = torch.aten.view %5140, %5141 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5142, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6218 = torch.constant.int 32 - %int8_6219 = torch.constant.int 8 - %int128_6220 = torch.constant.int 128 - %5143 = torch.prim.ListConstruct %5132, %int32_6218, %int8_6219, %int128_6220 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5144 = torch.aten.view %5142, %5143 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5144, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_6221 = torch.constant.int 4 - %int32_6222 = torch.constant.int 32 - %int8_6223 = torch.constant.int 8 - %int128_6224 = torch.constant.int 128 - %5145 = torch.prim.ListConstruct %int4_6221, %398, %int32_6222, %int8_6223, %int128_6224 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5146 = torch.aten.view %5062, %5145 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5146, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_6225 = torch.constant.int 4 - %5147 = torch.aten.mul.int %int4_6225, %398 : !torch.int, !torch.int -> !torch.int - %int32_6226 = torch.constant.int 32 - %int8_6227 = torch.constant.int 8 - %int128_6228 = torch.constant.int 128 - %5148 = torch.prim.ListConstruct %5147, %int32_6226, %int8_6227, %int128_6228 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5149 = torch.aten.view %5146, %5148 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5149, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_6229 = torch.constant.int 1 - %int1_6230 = torch.constant.int 1 - %5150 = torch.aten.add.Scalar %5120, %int1_6229, %int1_6230 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5150, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_6231 = torch.constant.int 4 - %5151 = torch.aten.mul.int %int4_6231, %398 : !torch.int, !torch.int -> !torch.int - %5152 = torch.prim.ListConstruct %5151 : (!torch.int) -> !torch.list - %5153 = torch.aten.view %5150, %5152 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %5153, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %5154 = torch.prim.ListConstruct %5153 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_6232 = torch.constant.bool false - %5155 = torch.aten.index_put %5144, %5154, %5149, %false_6232 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5155, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_6233 = torch.constant.int 32 - %int2_6234 = torch.constant.int 2 - %int32_6235 = torch.constant.int 32 - %int8_6236 = torch.constant.int 8 - %int128_6237 = 
torch.constant.int 128 - %5156 = torch.prim.ListConstruct %389, %int32_6233, %int2_6234, %int32_6235, %int8_6236, %int128_6237 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5157 = torch.aten.view %5155, %5156 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5157, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6238 = torch.constant.int 2097152 - %5158 = torch.prim.ListConstruct %389, %int2097152_6238 : (!torch.int, !torch.int) -> !torch.list - %5159 = torch.aten.view %5157, %5158 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5159, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_6239 = torch.constant.int -2 - %5160 = torch.aten.unsqueeze %5118, %int-2_6239 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5160, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int2_6183 = torch.constant.int 2 + %5113 = torch.aten.transpose.int %5112#0, %int1_6182, %int2_6183 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5113, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_6184 = torch.constant.int 4 + %int4096_6185 = torch.constant.int 4096 + %5114 = torch.prim.ListConstruct %int4_6184, %298, %int4096_6185 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5115 = torch.aten.view %5113, %5114 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5115, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_6186 = torch.constant.int -2 + %int-1_6187 = torch.constant.int -1 + %5116 = torch.aten.transpose.int %159, %int-2_6186, %int-1_6187 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6188 = torch.constant.int 5 + %5117 = torch.prims.convert_element_type %5116, %int5_6188 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_6189 = torch.constant.int 4096 + %5118 = torch.prim.ListConstruct %342, %int4096_6189 : (!torch.int, !torch.int) -> !torch.list + %5119 = torch.aten.view %5115, %5118 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5119, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5120 = torch.aten.mm %5119, %5117 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5120, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_6190 = torch.constant.int 4 + %int4096_6191 = torch.constant.int 4096 + %5121 = torch.prim.ListConstruct %int4_6190, %298, %int4096_6191 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5122 = torch.aten.view %5120, %5121 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5122, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_6192 = torch.constant.int 1 + %5123 = torch.aten.add.Tensor %4889, %5122, %int1_6192 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> 
!torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5123, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_6193 = torch.constant.int 6 + %5124 = torch.prims.convert_element_type %5123, %int6_6193 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5124, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_6194 = torch.constant.int 2 + %5125 = torch.aten.pow.Tensor_Scalar %5124, %int2_6194 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5125, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_6195 = torch.constant.int -1 + %5126 = torch.prim.ListConstruct %int-1_6195 : (!torch.int) -> !torch.list + %true_6196 = torch.constant.bool true + %none_6197 = torch.constant.none + %5127 = torch.aten.mean.dim %5125, %5126, %true_6196, %none_6197 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5127, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_6198 = torch.constant.float 9.9999997473787516E-6 + %int1_6199 = torch.constant.int 1 + %5128 = torch.aten.add.Scalar %5127, %float9.999990e-06_6198, %int1_6199 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5128, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %5129 = torch.aten.rsqrt %5128 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5129, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %5130 = torch.aten.mul.Tensor %5124, %5129 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5130, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_6200 = torch.constant.int 5 + %5131 = torch.prims.convert_element_type %5130, %int5_6200 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5131, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %5132 = torch.aten.mul.Tensor %160, %5131 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5132, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_6201 = torch.constant.int 5 + %5133 = torch.prims.convert_element_type %5132, %int5_6201 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5133, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_6202 = torch.constant.int -2 + %int-1_6203 = torch.constant.int -1 + %5134 = torch.aten.transpose.int %161, %int-2_6202, %int-1_6203 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6204 = torch.constant.int 5 + %5135 = torch.prims.convert_element_type %5134, %int5_6204 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_6205 = torch.constant.int 4096 + %5136 = torch.prim.ListConstruct %342, %int4096_6205 : (!torch.int, !torch.int) -> !torch.list + %5137 = torch.aten.view %5133, %5136 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> 
!torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5137, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5138 = torch.aten.mm %5137, %5135 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %5138, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_6206 = torch.constant.int 4 + %int14336_6207 = torch.constant.int 14336 + %5139 = torch.prim.ListConstruct %int4_6206, %298, %int14336_6207 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5140 = torch.aten.view %5138, %5139 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %5140, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %5141 = torch.aten.silu %5140 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %5141, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_6208 = torch.constant.int -2 + %int-1_6209 = torch.constant.int -1 + %5142 = torch.aten.transpose.int %162, %int-2_6208, %int-1_6209 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6210 = torch.constant.int 5 + %5143 = torch.prims.convert_element_type %5142, %int5_6210 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_6211 = torch.constant.int 4096 + %5144 = torch.prim.ListConstruct %342, %int4096_6211 : (!torch.int, !torch.int) -> !torch.list + %5145 = torch.aten.view %5133, %5144 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5145, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5146 = torch.aten.mm %5145, %5143 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %5146, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_6212 = torch.constant.int 4 + %int14336_6213 = torch.constant.int 14336 + %5147 = torch.prim.ListConstruct %int4_6212, %298, %int14336_6213 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5148 = torch.aten.view %5146, %5147 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %5148, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %5149 = torch.aten.mul.Tensor %5141, %5148 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %5149, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_6214 = torch.constant.int -2 + %int-1_6215 = torch.constant.int -1 + %5150 = torch.aten.transpose.int %163, %int-2_6214, %int-1_6215 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_6216 = torch.constant.int 5 + %5151 = torch.prims.convert_element_type %5150, %int5_6216 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_6217 = torch.constant.int 14336 + %5152 = torch.prim.ListConstruct %342, %int14336_6217 : (!torch.int, !torch.int) -> !torch.list + %5153 = torch.aten.view %5149, %5152 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %5153, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> 
: !torch.vtensor<[?,14336],f16> + %5154 = torch.aten.mm %5153, %5151 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5154, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_6218 = torch.constant.int 4 + %int4096_6219 = torch.constant.int 4096 + %5155 = torch.prim.ListConstruct %int4_6218, %298, %int4096_6219 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5156 = torch.aten.view %5154, %5155 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5156, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_6220 = torch.constant.int 1 + %5157 = torch.aten.add.Tensor %5123, %5156, %int1_6220 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5157, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_6221 = torch.constant.int 6 + %5158 = torch.prims.convert_element_type %5157, %int6_6221 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5158, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_6222 = torch.constant.int 2 + %5159 = torch.aten.pow.Tensor_Scalar %5158, %int2_6222 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5159, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_6223 = torch.constant.int -1 + %5160 = torch.prim.ListConstruct %int-1_6223 : (!torch.int) -> !torch.list + %true_6224 = torch.constant.bool true + %none_6225 = torch.constant.none + %5161 = torch.aten.mean.dim %5159, %5160, %true_6224, %none_6225 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5161, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_6226 = torch.constant.float 9.9999997473787516E-6 + %int1_6227 = torch.constant.int 1 + %5162 = torch.aten.add.Scalar %5161, %float9.999990e-06_6226, %int1_6227 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5162, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %5163 = torch.aten.rsqrt %5162 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5163, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %5164 = torch.aten.mul.Tensor %5158, %5163 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5164, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_6228 = torch.constant.int 5 + %5165 = torch.prims.convert_element_type %5164, %int5_6228 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5165, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %5166 = torch.aten.mul.Tensor %164, %5165 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5166, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_6229 = torch.constant.int 5 + %5167 = 
torch.prims.convert_element_type %5166, %int5_6229 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5167, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_6230 = torch.constant.int -2 + %int-1_6231 = torch.constant.int -1 + %5168 = torch.aten.transpose.int %165, %int-2_6230, %int-1_6231 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6232 = torch.constant.int 5 + %5169 = torch.prims.convert_element_type %5168, %int5_6232 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_6233 = torch.constant.int 4096 + %5170 = torch.prim.ListConstruct %342, %int4096_6233 : (!torch.int, !torch.int) -> !torch.list + %5171 = torch.aten.view %5167, %5170 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5171, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5172 = torch.aten.mm %5171, %5169 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5172, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_6234 = torch.constant.int 4 + %int4096_6235 = torch.constant.int 4096 + %5173 = torch.prim.ListConstruct %int4_6234, %298, %int4096_6235 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5174 = torch.aten.view %5172, %5173 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5174, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_6236 = torch.constant.int -2 + %int-1_6237 = torch.constant.int -1 + %5175 = torch.aten.transpose.int %166, %int-2_6236, %int-1_6237 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6238 = torch.constant.int 5 + %5176 = torch.prims.convert_element_type %5175, %int5_6238 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_6239 = torch.constant.int 4096 + %5177 = torch.prim.ListConstruct %342, %int4096_6239 : (!torch.int, !torch.int) -> !torch.list + %5178 = torch.aten.view %5167, %5177 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5178, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5179 = torch.aten.mm %5178, %5176 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %5179, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> %int4_6240 = torch.constant.int 4 - %int8_6241 = torch.constant.int 8 - %int4_6242 = torch.constant.int 4 - %int128_6243 = torch.constant.int 128 - %5161 = torch.prim.ListConstruct %int4_6240, %5103, %int8_6241, %int4_6242, %int128_6243 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_6244 = torch.constant.bool false - %5162 = torch.aten.expand %5160, %5161, %false_6244 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5162, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_6245 = torch.constant.int 0 - %5163 = torch.aten.clone %5162, %int0_6245 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - 
torch.bind_symbolic_shape %5163, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int1024_6241 = torch.constant.int 1024 + %5180 = torch.prim.ListConstruct %int4_6240, %298, %int1024_6241 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5181 = torch.aten.view %5179, %5180 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %5181, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_6242 = torch.constant.int -2 + %int-1_6243 = torch.constant.int -1 + %5182 = torch.aten.transpose.int %167, %int-2_6242, %int-1_6243 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6244 = torch.constant.int 5 + %5183 = torch.prims.convert_element_type %5182, %int5_6244 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_6245 = torch.constant.int 4096 + %5184 = torch.prim.ListConstruct %342, %int4096_6245 : (!torch.int, !torch.int) -> !torch.list + %5185 = torch.aten.view %5167, %5184 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5185, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5186 = torch.aten.mm %5185, %5183 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %5186, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> %int4_6246 = torch.constant.int 4 - %int32_6247 = torch.constant.int 32 - %int128_6248 = torch.constant.int 128 - %5164 = torch.prim.ListConstruct %int4_6246, %5103, %int32_6247, %int128_6248 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5165 = torch.aten._unsafe_view %5163, %5164 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5165, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_6249 = torch.constant.int -2 - %5166 = torch.aten.unsqueeze %5062, %int-2_6249 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5166, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_6250 = torch.constant.int 1 - %5167 = torch.aten.size.int %5056, %int1_6250 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int + %int1024_6247 = torch.constant.int 1024 + %5187 = torch.prim.ListConstruct %int4_6246, %298, %int1024_6247 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5188 = torch.aten.view %5186, %5187 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %5188, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_6248 = torch.constant.int 4 + %int32_6249 = torch.constant.int 32 + %int128_6250 = torch.constant.int 128 + %5189 = torch.prim.ListConstruct %int4_6248, %298, %int32_6249, %int128_6250 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5190 = torch.aten.view %5174, %5189 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5190, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int4_6251 = torch.constant.int 4 %int8_6252 = torch.constant.int 8 - %int4_6253 = torch.constant.int 4 - %int128_6254 = torch.constant.int 128 - %5168 = 
torch.prim.ListConstruct %int4_6251, %5167, %int8_6252, %int4_6253, %int128_6254 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_6255 = torch.constant.bool false - %5169 = torch.aten.expand %5166, %5168, %false_6255 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5169, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_6256 = torch.constant.int 0 - %5170 = torch.aten.clone %5169, %int0_6256 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5170, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_6257 = torch.constant.int 4 - %int32_6258 = torch.constant.int 32 - %int128_6259 = torch.constant.int 128 - %5171 = torch.prim.ListConstruct %int4_6257, %5167, %int32_6258, %int128_6259 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5172 = torch.aten._unsafe_view %5170, %5171 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5172, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_6260 = torch.constant.int 1 - %int2_6261 = torch.constant.int 2 - %5173 = torch.aten.transpose.int %5090, %int1_6260, %int2_6261 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5173, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6262 = torch.constant.int 1 - %int2_6263 = torch.constant.int 2 - %5174 = torch.aten.transpose.int %5165, %int1_6262, %int2_6263 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5174, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6264 = torch.constant.int 1 - %int2_6265 = torch.constant.int 2 - %5175 = torch.aten.transpose.int %5172, %int1_6264, %int2_6265 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5175, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_6266 = torch.constant.float 0.000000e+00 - %true_6267 = torch.constant.bool true - %none_6268 = torch.constant.none - %none_6269 = torch.constant.none - %5176:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5173, %5174, %5175, %float0.000000e00_6266, %true_6267, %none_6268, %none_6269) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %5176#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6270 = torch.constant.int 1 - %int2_6271 = torch.constant.int 2 - %5177 = torch.aten.transpose.int %5176#0, %int1_6270, %int2_6271 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5177, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_6272 = torch.constant.int 4 - %int4096_6273 = torch.constant.int 4096 - %5178 = torch.prim.ListConstruct %int4_6272, %5075, %int4096_6273 : (!torch.int, 
!torch.int, !torch.int) -> !torch.list - %5179 = torch.aten.view %5177, %5178 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5179, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6274 = torch.constant.int -2 - %int-1_6275 = torch.constant.int -1 - %5180 = torch.aten.transpose.int %221, %int-2_6274, %int-1_6275 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6276 = torch.constant.int 4 - %5181 = torch.aten.mul.int %int4_6276, %5075 : !torch.int, !torch.int -> !torch.int - %int4096_6277 = torch.constant.int 4096 - %5182 = torch.prim.ListConstruct %5181, %int4096_6277 : (!torch.int, !torch.int) -> !torch.list - %5183 = torch.aten.view %5179, %5182 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5183, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5184 = torch.aten.mm %5183, %5180 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5184, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_6278 = torch.constant.int 4 - %int4096_6279 = torch.constant.int 4096 - %5185 = torch.prim.ListConstruct %int4_6278, %5075, %int4096_6279 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5186 = torch.aten.view %5184, %5185 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5186, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int128_6253 = torch.constant.int 128 + %5191 = torch.prim.ListConstruct %int4_6251, %298, %int8_6252, %int128_6253 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5192 = torch.aten.view %5181, %5191 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5192, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_6254 = torch.constant.int 4 + %int8_6255 = torch.constant.int 8 + %int128_6256 = torch.constant.int 128 + %5193 = torch.prim.ListConstruct %int4_6254, %298, %int8_6255, %int128_6256 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5194 = torch.aten.view %5188, %5193 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5194, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_6257 = torch.constant.int 131072 + %none_6258 = torch.constant.none + %none_6259 = torch.constant.none + %cpu_6260 = torch.constant.device "cpu" + %false_6261 = torch.constant.bool false + %5195 = torch.aten.arange %int131072_6257, %none_6258, %none_6259, %cpu_6260, %false_6261 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_6262 = torch.constant.int 0 + %int128_6263 = torch.constant.int 128 + %int2_6264 = torch.constant.int 2 + %int4_6265 = torch.constant.int 4 + %none_6266 = torch.constant.none + %cpu_6267 = torch.constant.device "cpu" + %false_6268 = torch.constant.bool false + %5196 = torch.aten.arange.start_step %int0_6262, %int128_6263, %int2_6264, %int4_6265, %none_6266, %cpu_6267, %false_6268 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_6269 = torch.constant.int 6 + %5197 = 
torch.prims.convert_element_type %5196, %int6_6269 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_6270 = torch.constant.int 128 + %5198 = torch.aten.div.Scalar %5197, %int128_6270 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_6271 = torch.constant.float 5.000000e+05 + %5199 = torch.aten.pow.Scalar %float5.000000e05_6271, %5198 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5200 = torch.aten.reciprocal %5199 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_6272 = torch.constant.float 1.000000e+00 + %5201 = torch.aten.mul.Scalar %5200, %float1.000000e00_6272 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %5202 = torch.aten.reciprocal %5201 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_6273 = torch.constant.float 6.2831853071795862 + %5203 = torch.aten.mul.Scalar %5202, %float6.283190e00_6273 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_6274 = torch.constant.float 8.192000e+03 + %5204 = torch.aten.gt.Scalar %5203, %float8.192000e03_6274 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_6275 = torch.constant.int 8 + %5205 = torch.aten.div.Scalar %5201, %int8_6275 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5206 = torch.aten.where.self %5204, %5205, %5201 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5207 = torch.aten.reciprocal %5203 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_6276 = torch.constant.int 8192 + %5208 = torch.aten.mul.Scalar %5207, %int8192_6276 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_6277 = torch.constant.int 1 + %int1_6278 = torch.constant.int 1 + %5209 = torch.aten.sub.Scalar %5208, %int1_6277, %int1_6278 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_6279 = torch.constant.int 3 + %5210 = torch.aten.div.Scalar %5209, %int3_6279 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_6280 = torch.constant.int 1 - %5187 = torch.aten.add.Tensor %5025, %5186, %int1_6280 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5187, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_6281 = torch.constant.int 6 - %5188 = torch.prims.convert_element_type %5187, %int6_6281 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5188, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_6282 = torch.constant.int 2 - %5189 = torch.aten.pow.Tensor_Scalar %5188, %int2_6282 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5189, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_6283 = torch.constant.int -1 - %5190 = torch.prim.ListConstruct %int-1_6283 : (!torch.int) -> !torch.list - %true_6284 = torch.constant.bool true - %none_6285 = torch.constant.none - %5191 = torch.aten.mean.dim %5189, %5190, %true_6284, %none_6285 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5191, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - 
%float9.999990e-06_6286 = torch.constant.float 9.9999997473787516E-6 - %int1_6287 = torch.constant.int 1 - %5192 = torch.aten.add.Scalar %5191, %float9.999990e-06_6286, %int1_6287 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5192, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5193 = torch.aten.rsqrt %5192 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5193, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5194 = torch.aten.mul.Tensor %5188, %5193 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5194, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6288 = torch.constant.int 5 - %5195 = torch.prims.convert_element_type %5194, %int5_6288 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5195, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %5196 = torch.aten.mul.Tensor %222, %5195 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5196, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6289 = torch.constant.int 5 - %5197 = torch.prims.convert_element_type %5196, %int5_6289 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5197, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6290 = torch.constant.int -2 - %int-1_6291 = torch.constant.int -1 - %5198 = torch.aten.transpose.int %223, %int-2_6290, %int-1_6291 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_6292 = torch.constant.int 4 - %5199 = torch.aten.mul.int %int4_6292, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6293 = torch.constant.int 4096 - %5200 = torch.prim.ListConstruct %5199, %int4096_6293 : (!torch.int, !torch.int) -> !torch.list - %5201 = torch.aten.view %5197, %5200 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5201, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5202 = torch.aten.mm %5201, %5198 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5202, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_6294 = torch.constant.int 4 - %int14336_6295 = torch.constant.int 14336 - %5203 = torch.prim.ListConstruct %int4_6294, %306, %int14336_6295 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5204 = torch.aten.view %5202, %5203 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5204, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %5205 = torch.aten.silu %5204 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5205, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_6296 = torch.constant.int -2 - %int-1_6297 = torch.constant.int -1 - %5206 = torch.aten.transpose.int %224, %int-2_6296, %int-1_6297 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> 
!torch.vtensor<[4096,14336],f16> - %int4_6298 = torch.constant.int 4 - %5207 = torch.aten.mul.int %int4_6298, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6299 = torch.constant.int 4096 - %5208 = torch.prim.ListConstruct %5207, %int4096_6299 : (!torch.int, !torch.int) -> !torch.list - %5209 = torch.aten.view %5197, %5208 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5209, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5210 = torch.aten.mm %5209, %5206 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5210, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_6300 = torch.constant.int 4 - %int14336_6301 = torch.constant.int 14336 - %5211 = torch.prim.ListConstruct %int4_6300, %306, %int14336_6301 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5212 = torch.aten.view %5210, %5211 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5212, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %5213 = torch.aten.mul.Tensor %5205, %5212 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5213, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_6302 = torch.constant.int -2 - %int-1_6303 = torch.constant.int -1 - %5214 = torch.aten.transpose.int %225, %int-2_6302, %int-1_6303 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int1_6281 = torch.constant.int 1 + %5211 = torch.aten.rsub.Scalar %5210, %int1_6280, %int1_6281 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %5212 = torch.aten.mul.Tensor %5211, %5206 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_6282 = torch.constant.int 8 + %5213 = torch.aten.div.Scalar %5212, %int8_6282 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5214 = torch.aten.mul.Tensor %5210, %5206 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_6283 = torch.constant.int 1 + %5215 = torch.aten.add.Tensor %5213, %5214, %int1_6283 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_6284 = torch.constant.float 2.048000e+03 + %5216 = torch.aten.lt.Scalar %5203, %float2.048000e03_6284 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5217 = torch.aten.bitwise_not %5216 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_6285 = torch.constant.float 8.192000e+03 + %5218 = torch.aten.gt.Scalar %5203, %float8.192000e03_6285 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5219 = torch.aten.bitwise_not %5218 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5220 = torch.aten.mul.Tensor %5217, %5219 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5221 = torch.aten.where.self %5220, %5215, %5206 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5222 = torch.prim.ListConstruct %5221, %5221 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_6286 = torch.constant.int -1 + %5223 = torch.aten.cat %5222, %int-1_6286 : !torch.list, 
!torch.int -> !torch.vtensor<[128],f32> + %int6_6287 = torch.constant.int 6 + %5224 = torch.prims.convert_element_type %5223, %int6_6287 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_6288 = torch.constant.int 1 + %5225 = torch.aten.unsqueeze %5195, %int1_6288 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_6289 = torch.constant.int 6 + %5226 = torch.prims.convert_element_type %5225, %int6_6289 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_6290 = torch.constant.int 0 + %5227 = torch.aten.unsqueeze %5224, %int0_6290 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_6291 = torch.constant.int 6 + %5228 = torch.prims.convert_element_type %5227, %int6_6291 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %5229 = torch.aten.mul.Tensor %5226, %5228 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %5230 = torch.aten.cos %5229 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_6292 = torch.constant.int 5 + %5231 = torch.prims.convert_element_type %5230, %int5_6292 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %5232 = torch.aten.sin %5229 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_6293 = torch.constant.int 5 + %5233 = torch.prims.convert_element_type %5232, %int5_6293 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_6294 = torch.constant.int 0 + %int0_6295 = torch.constant.int 0 + %int1_6296 = torch.constant.int 1 + %5234 = torch.aten.slice.Tensor %5231, %int0_6294, %int0_6295, %298, %int1_6296 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5234, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_6297 = torch.constant.int 1 + %int0_6298 = torch.constant.int 0 + %int9223372036854775807_6299 = torch.constant.int 9223372036854775807 + %int1_6300 = torch.constant.int 1 + %5235 = torch.aten.slice.Tensor %5234, %int1_6297, %int0_6298, %int9223372036854775807_6299, %int1_6300 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5235, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_6301 = torch.constant.int 0 + %int0_6302 = torch.constant.int 0 + %int1_6303 = torch.constant.int 1 + %5236 = torch.aten.slice.Tensor %5233, %int0_6301, %int0_6302, %298, %int1_6303 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5236, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_6304 = torch.constant.int 1 - %5215 = torch.aten.size.int %5204, %int1_6304 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_6305 = torch.constant.int 4 - %5216 = torch.aten.mul.int %int4_6305, %5215 : !torch.int, !torch.int -> !torch.int - %int14336_6306 = torch.constant.int 14336 - %5217 = torch.prim.ListConstruct %5216, %int14336_6306 : (!torch.int, !torch.int) -> !torch.list - %5218 = torch.aten.view %5213, %5217 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5218, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> 
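// Annotation (editorial reading of the '+' ops above, which build the rotary-embedding
// cos/sin tables inline rather than via a custom kernel): torch.aten.arange.start_step
// yields the 64 even channel indices k, giving inv_freq[k] = 500000^(-2k/128); the
// wavelength lambda = 2*pi / inv_freq is then band-scaled in what appears to be the
// llama3 rope-scaling rule -- lambda > 8192 divides the frequency by the scale factor 8,
// lambda < 2048 is kept as-is, and the band in between blends the two with
// smooth = (8192/lambda - 1)/3. The two 64-wide halves are concatenated into a [128]
// vector, multiplied against positions arange(131072) to form the [131072,128] angle
// table, and its cos/sin (cast to f16) are sliced to the runtime sequence length
// (s0 * 32 per the torch.bind_symbolic_shape annotations) before being broadcast over
// the batch of 4 and the per-head dimension.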
- %5219 = torch.aten.mm %5218, %5214 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5219, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_6307 = torch.constant.int 4 - %int4096_6308 = torch.constant.int 4096 - %5220 = torch.prim.ListConstruct %int4_6307, %5215, %int4096_6308 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5221 = torch.aten.view %5219, %5220 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5221, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int0_6305 = torch.constant.int 0 + %int9223372036854775807_6306 = torch.constant.int 9223372036854775807 + %int1_6307 = torch.constant.int 1 + %5237 = torch.aten.slice.Tensor %5236, %int1_6304, %int0_6305, %int9223372036854775807_6306, %int1_6307 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5237, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_6308 = torch.constant.int 0 + %5238 = torch.aten.unsqueeze %5235, %int0_6308 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5238, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int1_6309 = torch.constant.int 1 - %5222 = torch.aten.add.Tensor %5187, %5221, %int1_6309 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5222, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_6310 = torch.constant.int 6 - %5223 = torch.prims.convert_element_type %5222, %int6_6310 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5223, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_6311 = torch.constant.int 2 - %5224 = torch.aten.pow.Tensor_Scalar %5223, %int2_6311 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5224, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_6312 = torch.constant.int -1 - %5225 = torch.prim.ListConstruct %int-1_6312 : (!torch.int) -> !torch.list - %true_6313 = torch.constant.bool true - %none_6314 = torch.constant.none - %5226 = torch.aten.mean.dim %5224, %5225, %true_6313, %none_6314 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5226, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_6315 = torch.constant.float 9.9999997473787516E-6 - %int1_6316 = torch.constant.int 1 - %5227 = torch.aten.add.Scalar %5226, %float9.999990e-06_6315, %int1_6316 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5227, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5228 = torch.aten.rsqrt %5227 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5228, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5229 = torch.aten.mul.Tensor %5223, %5228 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape 
%5229, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6317 = torch.constant.int 5 - %5230 = torch.prims.convert_element_type %5229, %int5_6317 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5230, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %5231 = torch.aten.mul.Tensor %226, %5230 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5231, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6318 = torch.constant.int 5 - %5232 = torch.prims.convert_element_type %5231, %int5_6318 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5232, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6319 = torch.constant.int -2 - %int-1_6320 = torch.constant.int -1 - %5233 = torch.aten.transpose.int %227, %int-2_6319, %int-1_6320 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6321 = torch.constant.int 4 - %5234 = torch.aten.mul.int %int4_6321, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6322 = torch.constant.int 4096 - %5235 = torch.prim.ListConstruct %5234, %int4096_6322 : (!torch.int, !torch.int) -> !torch.list - %5236 = torch.aten.view %5232, %5235 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5236, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5237 = torch.aten.mm %5236, %5233 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5237, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_6323 = torch.constant.int 4 - %int4096_6324 = torch.constant.int 4096 - %5238 = torch.prim.ListConstruct %int4_6323, %306, %int4096_6324 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5239 = torch.aten.view %5237, %5238 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5239, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6325 = torch.constant.int -2 - %int-1_6326 = torch.constant.int -1 - %5240 = torch.aten.transpose.int %228, %int-2_6325, %int-1_6326 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6327 = torch.constant.int 4 - %5241 = torch.aten.mul.int %int4_6327, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6328 = torch.constant.int 4096 - %5242 = torch.prim.ListConstruct %5241, %int4096_6328 : (!torch.int, !torch.int) -> !torch.list - %5243 = torch.aten.view %5232, %5242 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5243, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5244 = torch.aten.mm %5243, %5240 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %5244, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_6329 = torch.constant.int 4 - %int1024_6330 = torch.constant.int 1024 - %5245 = torch.prim.ListConstruct %int4_6329, %306, %int1024_6330 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5246 = torch.aten.view %5244, %5245 : 
!torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %5246, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_6331 = torch.constant.int -2 - %int-1_6332 = torch.constant.int -1 - %5247 = torch.aten.transpose.int %229, %int-2_6331, %int-1_6332 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6333 = torch.constant.int 4 - %5248 = torch.aten.mul.int %int4_6333, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6334 = torch.constant.int 4096 - %5249 = torch.prim.ListConstruct %5248, %int4096_6334 : (!torch.int, !torch.int) -> !torch.list - %5250 = torch.aten.view %5232, %5249 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5250, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5251 = torch.aten.mm %5250, %5247 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %5251, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_6335 = torch.constant.int 4 - %int1024_6336 = torch.constant.int 1024 - %5252 = torch.prim.ListConstruct %int4_6335, %306, %int1024_6336 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5253 = torch.aten.view %5251, %5252 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %5253, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_6337 = torch.constant.int 4 - %int32_6338 = torch.constant.int 32 - %int128_6339 = torch.constant.int 128 - %5254 = torch.prim.ListConstruct %int4_6337, %306, %int32_6338, %int128_6339 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5255 = torch.aten.view %5239, %5254 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5255, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_6340 = torch.constant.int 4 - %int8_6341 = torch.constant.int 8 - %int128_6342 = torch.constant.int 128 - %5256 = torch.prim.ListConstruct %int4_6340, %306, %int8_6341, %int128_6342 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5257 = torch.aten.view %5246, %5256 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5257, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_6343 = torch.constant.int 4 - %int8_6344 = torch.constant.int 8 - %int128_6345 = torch.constant.int 128 - %5258 = torch.prim.ListConstruct %int4_6343, %306, %int8_6344, %int128_6345 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5259 = torch.aten.view %5253, %5258 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5259, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int0_6310 = torch.constant.int 0 + %int9223372036854775807_6311 = torch.constant.int 9223372036854775807 + %int1_6312 = torch.constant.int 1 + %5239 = torch.aten.slice.Tensor %5238, %int1_6309, %int0_6310, %int9223372036854775807_6311, %int1_6312 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5239, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : 
!torch.vtensor<[1,?,128],f16> + %int2_6313 = torch.constant.int 2 + %5240 = torch.aten.unsqueeze %5239, %int2_6313 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5240, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_6314 = torch.constant.int 3 + %int0_6315 = torch.constant.int 0 + %int9223372036854775807_6316 = torch.constant.int 9223372036854775807 + %int1_6317 = torch.constant.int 1 + %5241 = torch.aten.slice.Tensor %5240, %int3_6314, %int0_6315, %int9223372036854775807_6316, %int1_6317 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5241, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_6318 = torch.constant.int 4 + %int1_6319 = torch.constant.int 1 + %int1_6320 = torch.constant.int 1 + %int1_6321 = torch.constant.int 1 + %5242 = torch.prim.ListConstruct %int4_6318, %int1_6319, %int1_6320, %int1_6321 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5243 = torch.aten.repeat %5241, %5242 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5243, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_6322 = torch.constant.int 0 + %5244 = torch.aten.unsqueeze %5237, %int0_6322 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5244, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_6323 = torch.constant.int 1 + %int0_6324 = torch.constant.int 0 + %int9223372036854775807_6325 = torch.constant.int 9223372036854775807 + %int1_6326 = torch.constant.int 1 + %5245 = torch.aten.slice.Tensor %5244, %int1_6323, %int0_6324, %int9223372036854775807_6325, %int1_6326 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5245, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_6327 = torch.constant.int 2 + %5246 = torch.aten.unsqueeze %5245, %int2_6327 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5246, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_6328 = torch.constant.int 3 + %int0_6329 = torch.constant.int 0 + %int9223372036854775807_6330 = torch.constant.int 9223372036854775807 + %int1_6331 = torch.constant.int 1 + %5247 = torch.aten.slice.Tensor %5246, %int3_6328, %int0_6329, %int9223372036854775807_6330, %int1_6331 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5247, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_6332 = torch.constant.int 4 + %int1_6333 = torch.constant.int 1 + %int1_6334 = torch.constant.int 1 + %int1_6335 = torch.constant.int 1 + %5248 = torch.prim.ListConstruct %int4_6332, %int1_6333, %int1_6334, %int1_6335 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5249 = torch.aten.repeat %5247, %5248 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5249, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %5250 = torch.aten.mul.Tensor %5190, %5243 : 
!torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5250, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_6336 = torch.constant.int 3 + %int0_6337 = torch.constant.int 0 + %int64_6338 = torch.constant.int 64 + %int1_6339 = torch.constant.int 1 + %5251 = torch.aten.slice.Tensor %5190, %int3_6336, %int0_6337, %int64_6338, %int1_6339 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %5251, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_6340 = torch.constant.int 3 + %int64_6341 = torch.constant.int 64 + %int9223372036854775807_6342 = torch.constant.int 9223372036854775807 + %int1_6343 = torch.constant.int 1 + %5252 = torch.aten.slice.Tensor %5190, %int3_6340, %int64_6341, %int9223372036854775807_6342, %int1_6343 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %5252, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %5253 = torch.aten.neg %5252 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %5253, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %5254 = torch.prim.ListConstruct %5253, %5251 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_6344 = torch.constant.int -1 + %5255 = torch.aten.cat %5254, %int-1_6344 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5255, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %5256 = torch.aten.mul.Tensor %5255, %5249 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5256, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_6345 = torch.constant.int 1 + %5257 = torch.aten.add.Tensor %5250, %5256, %int1_6345 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5257, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int131072_6346 = torch.constant.int 131072 %none_6347 = torch.constant.none %none_6348 = torch.constant.none %cpu_6349 = torch.constant.device "cpu" %false_6350 = torch.constant.bool false - %5260 = torch.aten.arange %int131072_6346, %none_6347, %none_6348, %cpu_6349, %false_6350 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %5258 = torch.aten.arange %int131072_6346, %none_6347, %none_6348, %cpu_6349, %false_6350 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> %int0_6351 = torch.constant.int 0 %int128_6352 = torch.constant.int 128 - %none_6353 = torch.constant.none - %none_6354 = torch.constant.none - %cpu_6355 = torch.constant.device "cpu" - %false_6356 = torch.constant.bool false - %5261 = torch.aten.arange.start %int0_6351, %int128_6352, %none_6353, %none_6354, %cpu_6355, %false_6356 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_6357 = torch.constant.int 2 - %5262 = torch.aten.floor_divide.Scalar %5261, 
%int2_6357 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> + %int2_6353 = torch.constant.int 2 + %int4_6354 = torch.constant.int 4 + %none_6355 = torch.constant.none + %cpu_6356 = torch.constant.device "cpu" + %false_6357 = torch.constant.bool false + %5259 = torch.aten.arange.start_step %int0_6351, %int128_6352, %int2_6353, %int4_6354, %none_6355, %cpu_6356, %false_6357 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> %int6_6358 = torch.constant.int 6 - %5263 = torch.prims.convert_element_type %5262, %int6_6358 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> + %5260 = torch.prims.convert_element_type %5259, %int6_6358 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> %int128_6359 = torch.constant.int 128 - %5264 = torch.aten.div.Scalar %5263, %int128_6359 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_6360 = torch.constant.float 2.000000e+00 - %5265 = torch.aten.mul.Scalar %5264, %float2.000000e00_6360 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_6361 = torch.constant.float 5.000000e+05 - %5266 = torch.aten.pow.Scalar %float5.000000e05_6361, %5265 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %5267 = torch.aten.reciprocal %5266 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_6362 = torch.constant.float 1.000000e+00 - %5268 = torch.aten.mul.Scalar %5267, %float1.000000e00_6362 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_6363 = torch.constant.int 1 - %5269 = torch.aten.unsqueeze %5260, %int1_6363 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_6364 = torch.constant.int 0 - %5270 = torch.aten.unsqueeze %5268, %int0_6364 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %5271 = torch.aten.mul.Tensor %5269, %5270 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_6365 = torch.constant.int 1 - %5272 = torch.aten.size.int %5239, %int1_6365 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_6366 = torch.constant.int 0 - %5273 = torch.aten.add.int %int0_6366, %5272 : !torch.int, !torch.int -> !torch.int - %int0_6367 = torch.constant.int 0 - %int0_6368 = torch.constant.int 0 + %5261 = torch.aten.div.Scalar %5260, %int128_6359 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_6360 = torch.constant.float 5.000000e+05 + %5262 = torch.aten.pow.Scalar %float5.000000e05_6360, %5261 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5263 = torch.aten.reciprocal %5262 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_6361 = torch.constant.float 1.000000e+00 + %5264 = torch.aten.mul.Scalar %5263, %float1.000000e00_6361 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %5265 = torch.aten.reciprocal %5264 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_6362 = torch.constant.float 6.2831853071795862 + %5266 = torch.aten.mul.Scalar %5265, %float6.283190e00_6362 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_6363 = torch.constant.float 8.192000e+03 + %5267 = torch.aten.gt.Scalar %5266, %float8.192000e03_6363 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_6364 = 
torch.constant.int 8 + %5268 = torch.aten.div.Scalar %5264, %int8_6364 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5269 = torch.aten.where.self %5267, %5268, %5264 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5270 = torch.aten.reciprocal %5266 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_6365 = torch.constant.int 8192 + %5271 = torch.aten.mul.Scalar %5270, %int8192_6365 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_6366 = torch.constant.int 1 + %int1_6367 = torch.constant.int 1 + %5272 = torch.aten.sub.Scalar %5271, %int1_6366, %int1_6367 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_6368 = torch.constant.int 3 + %5273 = torch.aten.div.Scalar %5272, %int3_6368 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_6369 = torch.constant.int 1 - %5274 = torch.aten.slice.Tensor %5271, %int0_6367, %int0_6368, %5273, %int1_6369 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5274, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> %int1_6370 = torch.constant.int 1 - %int0_6371 = torch.constant.int 0 - %int9223372036854775807_6372 = torch.constant.int 9223372036854775807 - %int1_6373 = torch.constant.int 1 - %5275 = torch.aten.slice.Tensor %5274, %int1_6370, %int0_6371, %int9223372036854775807_6372, %int1_6373 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5275, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_6374 = torch.constant.int 1 - %int0_6375 = torch.constant.int 0 - %int9223372036854775807_6376 = torch.constant.int 9223372036854775807 + %5274 = torch.aten.rsub.Scalar %5273, %int1_6369, %int1_6370 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %5275 = torch.aten.mul.Tensor %5274, %5269 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_6371 = torch.constant.int 8 + %5276 = torch.aten.div.Scalar %5275, %int8_6371 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5277 = torch.aten.mul.Tensor %5273, %5269 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_6372 = torch.constant.int 1 + %5278 = torch.aten.add.Tensor %5276, %5277, %int1_6372 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_6373 = torch.constant.float 2.048000e+03 + %5279 = torch.aten.lt.Scalar %5266, %float2.048000e03_6373 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5280 = torch.aten.bitwise_not %5279 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_6374 = torch.constant.float 8.192000e+03 + %5281 = torch.aten.gt.Scalar %5266, %float8.192000e03_6374 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5282 = torch.aten.bitwise_not %5281 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5283 = torch.aten.mul.Tensor %5280, %5282 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5284 = torch.aten.where.self %5283, %5278, %5269 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5285 = torch.prim.ListConstruct %5284, %5284 : 
(!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_6375 = torch.constant.int -1 + %5286 = torch.aten.cat %5285, %int-1_6375 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_6376 = torch.constant.int 6 + %5287 = torch.prims.convert_element_type %5286, %int6_6376 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> %int1_6377 = torch.constant.int 1 - %5276 = torch.aten.slice.Tensor %5275, %int1_6374, %int0_6375, %int9223372036854775807_6376, %int1_6377 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5276, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_6378 = torch.constant.int 0 - %5277 = torch.aten.unsqueeze %5276, %int0_6378 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5277, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_6379 = torch.constant.int 1 - %int0_6380 = torch.constant.int 0 - %int9223372036854775807_6381 = torch.constant.int 9223372036854775807 - %int1_6382 = torch.constant.int 1 - %5278 = torch.aten.slice.Tensor %5277, %int1_6379, %int0_6380, %int9223372036854775807_6381, %int1_6382 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5278, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_6383 = torch.constant.int 2 + %5288 = torch.aten.unsqueeze %5258, %int1_6377 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_6378 = torch.constant.int 6 + %5289 = torch.prims.convert_element_type %5288, %int6_6378 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_6379 = torch.constant.int 0 + %5290 = torch.aten.unsqueeze %5287, %int0_6379 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_6380 = torch.constant.int 6 + %5291 = torch.prims.convert_element_type %5290, %int6_6380 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %5292 = torch.aten.mul.Tensor %5289, %5291 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %5293 = torch.aten.cos %5292 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_6381 = torch.constant.int 5 + %5294 = torch.prims.convert_element_type %5293, %int5_6381 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %5295 = torch.aten.sin %5292 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_6382 = torch.constant.int 5 + %5296 = torch.prims.convert_element_type %5295, %int5_6382 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_6383 = torch.constant.int 0 %int0_6384 = torch.constant.int 0 - %int9223372036854775807_6385 = torch.constant.int 9223372036854775807 + %int1_6385 = torch.constant.int 1 + %5297 = torch.aten.slice.Tensor %5294, %int0_6383, %int0_6384, %298, %int1_6385 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5297, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_6386 = torch.constant.int 1 - %5279 = torch.aten.slice.Tensor %5278, %int2_6383, %int0_6384, %int9223372036854775807_6385, %int1_6386 : !torch.vtensor<[1,?,128],f32>, 
!torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5279, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_6387 = torch.constant.int 4 - %int1_6388 = torch.constant.int 1 + %int0_6387 = torch.constant.int 0 + %int9223372036854775807_6388 = torch.constant.int 9223372036854775807 %int1_6389 = torch.constant.int 1 - %5280 = torch.prim.ListConstruct %int4_6387, %int1_6388, %int1_6389 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5281 = torch.aten.repeat %5279, %5280 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %5281, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_6390 = torch.constant.int 6 - %5282 = torch.prims.convert_element_type %5255, %int6_6390 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %5282, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %5283 = torch_c.to_builtin_tensor %5282 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %5284 = torch_c.to_builtin_tensor %5281 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %5285 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%5283, %5284) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %5286 = torch_c.from_builtin_tensor %5285 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %5286, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_6391 = torch.constant.int 5 - %5287 = torch.prims.convert_element_type %5286, %int5_6391 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5287, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_6392 = torch.constant.int 131072 - %none_6393 = torch.constant.none - %none_6394 = torch.constant.none - %cpu_6395 = torch.constant.device "cpu" - %false_6396 = torch.constant.bool false - %5288 = torch.aten.arange %int131072_6392, %none_6393, %none_6394, %cpu_6395, %false_6396 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %5298 = torch.aten.slice.Tensor %5297, %int1_6386, %int0_6387, %int9223372036854775807_6388, %int1_6389 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5298, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_6390 = torch.constant.int 0 + %int0_6391 = torch.constant.int 0 + %int1_6392 = torch.constant.int 1 + %5299 = torch.aten.slice.Tensor %5296, %int0_6390, %int0_6391, %298, %int1_6392 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5299, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_6393 = torch.constant.int 1 + %int0_6394 = torch.constant.int 0 + %int9223372036854775807_6395 = torch.constant.int 9223372036854775807 + %int1_6396 = torch.constant.int 1 + %5300 = torch.aten.slice.Tensor %5299, %int1_6393, %int0_6394, %int9223372036854775807_6395, %int1_6396 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5300, [%294], affine_map<()[s0] -> (s0 * 32, 128)> 
: !torch.vtensor<[?,128],f16> %int0_6397 = torch.constant.int 0 - %int128_6398 = torch.constant.int 128 - %none_6399 = torch.constant.none - %none_6400 = torch.constant.none - %cpu_6401 = torch.constant.device "cpu" - %false_6402 = torch.constant.bool false - %5289 = torch.aten.arange.start %int0_6397, %int128_6398, %none_6399, %none_6400, %cpu_6401, %false_6402 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_6403 = torch.constant.int 2 - %5290 = torch.aten.floor_divide.Scalar %5289, %int2_6403 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_6404 = torch.constant.int 6 - %5291 = torch.prims.convert_element_type %5290, %int6_6404 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_6405 = torch.constant.int 128 - %5292 = torch.aten.div.Scalar %5291, %int128_6405 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_6406 = torch.constant.float 2.000000e+00 - %5293 = torch.aten.mul.Scalar %5292, %float2.000000e00_6406 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_6407 = torch.constant.float 5.000000e+05 - %5294 = torch.aten.pow.Scalar %float5.000000e05_6407, %5293 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %5295 = torch.aten.reciprocal %5294 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_6408 = torch.constant.float 1.000000e+00 - %5296 = torch.aten.mul.Scalar %5295, %float1.000000e00_6408 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %5301 = torch.aten.unsqueeze %5298, %int0_6397 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5301, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_6398 = torch.constant.int 1 + %int0_6399 = torch.constant.int 0 + %int9223372036854775807_6400 = torch.constant.int 9223372036854775807 + %int1_6401 = torch.constant.int 1 + %5302 = torch.aten.slice.Tensor %5301, %int1_6398, %int0_6399, %int9223372036854775807_6400, %int1_6401 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5302, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_6402 = torch.constant.int 2 + %5303 = torch.aten.unsqueeze %5302, %int2_6402 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5303, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_6403 = torch.constant.int 3 + %int0_6404 = torch.constant.int 0 + %int9223372036854775807_6405 = torch.constant.int 9223372036854775807 + %int1_6406 = torch.constant.int 1 + %5304 = torch.aten.slice.Tensor %5303, %int3_6403, %int0_6404, %int9223372036854775807_6405, %int1_6406 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5304, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_6407 = torch.constant.int 4 + %int1_6408 = torch.constant.int 1 %int1_6409 = torch.constant.int 1 - %5297 = torch.aten.unsqueeze %5288, %int1_6409 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_6410 = torch.constant.int 0 - %5298 = torch.aten.unsqueeze %5296, %int0_6410 : !torch.vtensor<[128],f32>, 
!torch.int -> !torch.vtensor<[1,128],f32> - %5299 = torch.aten.mul.Tensor %5297, %5298 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_6411 = torch.constant.int 1 - %5300 = torch.aten.size.int %5246, %int1_6411 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_6412 = torch.constant.int 0 - %5301 = torch.aten.add.int %int0_6412, %5300 : !torch.int, !torch.int -> !torch.int + %int1_6410 = torch.constant.int 1 + %5305 = torch.prim.ListConstruct %int4_6407, %int1_6408, %int1_6409, %int1_6410 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5306 = torch.aten.repeat %5304, %5305 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5306, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_6411 = torch.constant.int 0 + %5307 = torch.aten.unsqueeze %5300, %int0_6411 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5307, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_6412 = torch.constant.int 1 %int0_6413 = torch.constant.int 0 - %int0_6414 = torch.constant.int 0 + %int9223372036854775807_6414 = torch.constant.int 9223372036854775807 %int1_6415 = torch.constant.int 1 - %5302 = torch.aten.slice.Tensor %5299, %int0_6413, %int0_6414, %5301, %int1_6415 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5302, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_6416 = torch.constant.int 1 - %int0_6417 = torch.constant.int 0 - %int9223372036854775807_6418 = torch.constant.int 9223372036854775807 - %int1_6419 = torch.constant.int 1 - %5303 = torch.aten.slice.Tensor %5302, %int1_6416, %int0_6417, %int9223372036854775807_6418, %int1_6419 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5303, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %5308 = torch.aten.slice.Tensor %5307, %int1_6412, %int0_6413, %int9223372036854775807_6414, %int1_6415 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5308, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_6416 = torch.constant.int 2 + %5309 = torch.aten.unsqueeze %5308, %int2_6416 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5309, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_6417 = torch.constant.int 3 + %int0_6418 = torch.constant.int 0 + %int9223372036854775807_6419 = torch.constant.int 9223372036854775807 %int1_6420 = torch.constant.int 1 - %int0_6421 = torch.constant.int 0 - %int9223372036854775807_6422 = torch.constant.int 9223372036854775807 + %5310 = torch.aten.slice.Tensor %5309, %int3_6417, %int0_6418, %int9223372036854775807_6419, %int1_6420 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5310, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_6421 = torch.constant.int 4 + %int1_6422 = torch.constant.int 1 %int1_6423 = torch.constant.int 1 - %5304 = 
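// NOTE: the removed "-" ops above appear to rebuild the RoPE frequency table at runtime in
// f32 -- inv_freq[i] = 1 / 500000^(2*(i//2)/128), outer product with positions arange(131072) --
// whereas the "+" side instead slices the needed rows out of a precomputed [131072,128] f16
// table (%5296).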
torch.aten.slice.Tensor %5303, %int1_6420, %int0_6421, %int9223372036854775807_6422, %int1_6423 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5304, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_6424 = torch.constant.int 0 - %5305 = torch.aten.unsqueeze %5304, %int0_6424 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5305, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_6425 = torch.constant.int 1 + %int1_6424 = torch.constant.int 1 + %5311 = torch.prim.ListConstruct %int4_6421, %int1_6422, %int1_6423, %int1_6424 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5312 = torch.aten.repeat %5310, %5311 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5312, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %5313 = torch.aten.mul.Tensor %5192, %5306 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5313, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_6425 = torch.constant.int 3 %int0_6426 = torch.constant.int 0 - %int9223372036854775807_6427 = torch.constant.int 9223372036854775807 + %int64_6427 = torch.constant.int 64 %int1_6428 = torch.constant.int 1 - %5306 = torch.aten.slice.Tensor %5305, %int1_6425, %int0_6426, %int9223372036854775807_6427, %int1_6428 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5306, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_6429 = torch.constant.int 2 - %int0_6430 = torch.constant.int 0 + %5314 = torch.aten.slice.Tensor %5192, %int3_6425, %int0_6426, %int64_6427, %int1_6428 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %5314, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_6429 = torch.constant.int 3 + %int64_6430 = torch.constant.int 64 %int9223372036854775807_6431 = torch.constant.int 9223372036854775807 %int1_6432 = torch.constant.int 1 - %5307 = torch.aten.slice.Tensor %5306, %int2_6429, %int0_6430, %int9223372036854775807_6431, %int1_6432 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5307, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_6433 = torch.constant.int 4 + %5315 = torch.aten.slice.Tensor %5192, %int3_6429, %int64_6430, %int9223372036854775807_6431, %int1_6432 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %5315, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %5316 = torch.aten.neg %5315 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %5316, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %5317 = torch.prim.ListConstruct %5316, %5314 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_6433 = torch.constant.int -1 + %5318 = 
torch.aten.cat %5317, %int-1_6433 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5318, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %5319 = torch.aten.mul.Tensor %5318, %5312 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5319, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> %int1_6434 = torch.constant.int 1 - %int1_6435 = torch.constant.int 1 - %5308 = torch.prim.ListConstruct %int4_6433, %int1_6434, %int1_6435 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5309 = torch.aten.repeat %5307, %5308 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %5309, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_6436 = torch.constant.int 6 - %5310 = torch.prims.convert_element_type %5257, %int6_6436 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %5310, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %5311 = torch_c.to_builtin_tensor %5310 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %5312 = torch_c.to_builtin_tensor %5309 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %5313 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%5311, %5312) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %5314 = torch_c.from_builtin_tensor %5313 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %5314, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_6437 = torch.constant.int 5 - %5315 = torch.prims.convert_element_type %5314, %int5_6437 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5315, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_6438 = torch.constant.int 64 - %5316 = torch.aten.mul.Scalar %arg2, %int64_6438 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5316, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int50 = torch.constant.int 50 + %5320 = torch.aten.add.Tensor %5313, %5319, %int1_6434 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5320, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_6435 = torch.constant.int 32 + %5321 = torch.aten.mul.Scalar %arg2, %int32_6435 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5321, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int18 = torch.constant.int 18 + %int1_6436 = torch.constant.int 1 + %5322 = torch.aten.add.Scalar %5321, %int18, %int1_6436 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5322, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_6437 = torch.constant.int 2 + %5323 = torch.aten.mul.Scalar %5322, %int2_6437 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5323, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_6438 = torch.constant.int 0 %int1_6439 = 
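// NOTE: %5313..%5320 read as rotary position embedding applied directly in f16: with
// %5306/%5312 the broadcast cos/sin rows, k_rot = k*cos + rotate_half(k)*sin, where
// rotate_half(k) = cat(-k[..., 64:], k[..., :64]) (the slice/neg/cat ops above). The removed
// "-" path performed the same rotation in f32 through the @sharktank_rotary_embedding_* calls.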
torch.constant.int 1 - %5317 = torch.aten.add.Scalar %5316, %int50, %int1_6439 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5317, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %5324 = torch.aten.add.Scalar %5323, %int0_6438, %int1_6439 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5324, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %5325 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %5326 = torch.aten.view %5324, %5325 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %5326, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> %int4_6440 = torch.constant.int 4 %int32_6441 = torch.constant.int 32 %int8_6442 = torch.constant.int 8 %int128_6443 = torch.constant.int 128 - %5318 = torch.prim.ListConstruct %int4_6440, %398, %int32_6441, %int8_6442, %int128_6443 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5319 = torch.aten.view %5315, %5318 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5319, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_6444 = torch.constant.int 4 - %5320 = torch.aten.mul.int %int4_6444, %398 : !torch.int, !torch.int -> !torch.int - %int32_6445 = torch.constant.int 32 - %int8_6446 = torch.constant.int 8 - %int128_6447 = torch.constant.int 128 - %5321 = torch.prim.ListConstruct %5320, %int32_6445, %int8_6446, %int128_6447 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5322 = torch.aten.view %5319, %5321 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5322, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_6448 = torch.constant.int 4 - %5323 = torch.aten.mul.int %int4_6448, %398 : !torch.int, !torch.int -> !torch.int - %5324 = torch.prim.ListConstruct %5323 : (!torch.int) -> !torch.list - %5325 = torch.aten.view %5317, %5324 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %5325, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_6449 = torch.constant.int 32 - %int2_6450 = torch.constant.int 2 - %int32_6451 = torch.constant.int 32 + %5327 = torch.prim.ListConstruct %int4_6440, %296, %int32_6441, %int8_6442, %int128_6443 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5328 = torch.aten.view %5320, %5327 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5328, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_6444 = torch.constant.int 32 + %int8_6445 = torch.constant.int 8 + %int128_6446 = torch.constant.int 128 + %5329 = torch.prim.ListConstruct %504, %int32_6444, %int8_6445, %int128_6446 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5330 = torch.aten.view %5328, %5329 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %5330, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_6447 = torch.constant.int 1 + %int2_6448 = torch.constant.int 2 + %5331 = torch.aten.transpose.int %5330, %int1_6447, %int2_6448 : 
!torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5331, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_6449 = torch.constant.int 5 + %5332 = torch.prims.convert_element_type %5331, %int5_6449 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5332, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_6450 = torch.constant.int 32 + %int2_6451 = torch.constant.int 2 %int8_6452 = torch.constant.int 8 - %int128_6453 = torch.constant.int 128 - %5326 = torch.prim.ListConstruct %389, %int32_6449, %int2_6450, %int32_6451, %int8_6452, %int128_6453 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5327 = torch.aten.view %5159, %5326 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5327, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6454 = torch.constant.int 32 - %5328 = torch.aten.mul.int %389, %int32_6454 : !torch.int, !torch.int -> !torch.int - %int2_6455 = torch.constant.int 2 - %5329 = torch.aten.mul.int %5328, %int2_6455 : !torch.int, !torch.int -> !torch.int + %int32_6453 = torch.constant.int 32 + %int128_6454 = torch.constant.int 128 + %5333 = torch.prim.ListConstruct %297, %int32_6450, %int2_6451, %int8_6452, %int32_6453, %int128_6454 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5334 = torch.aten.view %5096, %5333 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5334, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_6455 = torch.constant.int 8 %int32_6456 = torch.constant.int 32 - %int8_6457 = torch.constant.int 8 - %int128_6458 = torch.constant.int 128 - %5330 = torch.prim.ListConstruct %5329, %int32_6456, %int8_6457, %int128_6458 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5331 = torch.aten.view %5327, %5330 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5331, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %5332 = torch.prim.ListConstruct %5325 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_6459 = torch.constant.bool false - %5333 = torch.aten.index_put %5331, %5332, %5322, %false_6459 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5333, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_6460 = torch.constant.int 32 - %int2_6461 = torch.constant.int 2 + %int128_6457 = torch.constant.int 128 + %5335 = torch.prim.ListConstruct %497, %int8_6455, %int32_6456, %int128_6457 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5336 = torch.aten.view %5334, %5335 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5336, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %5337 = torch.prim.ListConstruct %5326 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_6458 = torch.constant.bool false + %5338 = torch.aten.index_put %5336, 
%5337, %5332, %false_6458 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5338, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_6459 = torch.constant.int 32 + %int2_6460 = torch.constant.int 2 + %int8_6461 = torch.constant.int 8 %int32_6462 = torch.constant.int 32 - %int8_6463 = torch.constant.int 8 - %int128_6464 = torch.constant.int 128 - %5334 = torch.prim.ListConstruct %389, %int32_6460, %int2_6461, %int32_6462, %int8_6463, %int128_6464 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5335 = torch.aten.view %5333, %5334 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5335, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6465 = torch.constant.int 2097152 - %5336 = torch.prim.ListConstruct %389, %int2097152_6465 : (!torch.int, !torch.int) -> !torch.list - %5337 = torch.aten.view %5335, %5336 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5337, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_6466 = torch.constant.int 32 - %int2_6467 = torch.constant.int 2 + %int128_6463 = torch.constant.int 128 + %5339 = torch.prim.ListConstruct %297, %int32_6459, %int2_6460, %int8_6461, %int32_6462, %int128_6463 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5340 = torch.aten.view %5338, %5339 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5340, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_6464 = torch.constant.int 2097152 + %5341 = torch.prim.ListConstruct %297, %int2097152_6464 : (!torch.int, !torch.int) -> !torch.list + %5342 = torch.aten.view %5340, %5341 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5342, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_6465 = torch.constant.int 32 + %int2_6466 = torch.constant.int 2 + %int8_6467 = torch.constant.int 8 %int32_6468 = torch.constant.int 32 - %int8_6469 = torch.constant.int 8 - %int128_6470 = torch.constant.int 128 - %5338 = torch.prim.ListConstruct %389, %int32_6466, %int2_6467, %int32_6468, %int8_6469, %int128_6470 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5339 = torch.aten.view %5337, %5338 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5339, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> + %int128_6469 = torch.constant.int 128 + %5343 = torch.prim.ListConstruct %297, %int32_6465, %int2_6466, %int8_6467, %int32_6468, %int128_6469 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5344 = torch.aten.view %5342, %5343 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5344, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_6470 = torch.constant.int 8 %int32_6471 = torch.constant.int 32 - %int8_6472 = 
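// NOTE: this view/index_put/view round trip reads as a paged KV-cache write: %arg3 (via %5096)
// is viewed as [pages, 32 blocks, 2 (K/V), 8 kv-heads, 32 tokens/page, 128]
// (32*2*8*32*128 = 2097152), with flat plane index (page_id*32 + 18)*2 + 0 selecting this
// block's K plane (the V write below uses + 1), then flattened back to [?,2097152]. The "+"
// layout is kv-head-major ([..., 8, 32, 128]) where the removed layout was token-major
// ([..., 32, 8, 128]).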
torch.constant.int 8 - %int128_6473 = torch.constant.int 128 - %5340 = torch.prim.ListConstruct %5329, %int32_6471, %int8_6472, %int128_6473 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5341 = torch.aten.view %5339, %5340 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5341, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_6474 = torch.constant.int 4 - %int32_6475 = torch.constant.int 32 - %int8_6476 = torch.constant.int 8 - %int128_6477 = torch.constant.int 128 - %5342 = torch.prim.ListConstruct %int4_6474, %398, %int32_6475, %int8_6476, %int128_6477 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5343 = torch.aten.view %5259, %5342 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5343, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_6478 = torch.constant.int 4 - %5344 = torch.aten.mul.int %int4_6478, %398 : !torch.int, !torch.int -> !torch.int - %int32_6479 = torch.constant.int 32 - %int8_6480 = torch.constant.int 8 - %int128_6481 = torch.constant.int 128 - %5345 = torch.prim.ListConstruct %5344, %int32_6479, %int8_6480, %int128_6481 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5346 = torch.aten.view %5343, %5345 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5346, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_6482 = torch.constant.int 1 - %int1_6483 = torch.constant.int 1 - %5347 = torch.aten.add.Scalar %5317, %int1_6482, %int1_6483 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5347, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_6484 = torch.constant.int 4 - %5348 = torch.aten.mul.int %int4_6484, %398 : !torch.int, !torch.int -> !torch.int - %5349 = torch.prim.ListConstruct %5348 : (!torch.int) -> !torch.list - %5350 = torch.aten.view %5347, %5349 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %5350, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %5351 = torch.prim.ListConstruct %5350 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_6485 = torch.constant.bool false - %5352 = torch.aten.index_put %5341, %5351, %5346, %false_6485 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5352, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_6486 = torch.constant.int 32 + %int128_6472 = torch.constant.int 128 + %5345 = torch.prim.ListConstruct %497, %int8_6470, %int32_6471, %int128_6472 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5346 = torch.aten.view %5344, %5345 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5346, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_6473 = torch.constant.int 32 + %5347 = torch.aten.mul.Scalar %arg2, %int32_6473 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5347, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + 
%int18_6474 = torch.constant.int 18 + %int1_6475 = torch.constant.int 1 + %5348 = torch.aten.add.Scalar %5347, %int18_6474, %int1_6475 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5348, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_6476 = torch.constant.int 2 + %5349 = torch.aten.mul.Scalar %5348, %int2_6476 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5349, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_6477 = torch.constant.int 1 + %int1_6478 = torch.constant.int 1 + %5350 = torch.aten.add.Scalar %5349, %int1_6477, %int1_6478 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5350, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %5351 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %5352 = torch.aten.view %5350, %5351 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %5352, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_6479 = torch.constant.int 4 + %int32_6480 = torch.constant.int 32 + %int8_6481 = torch.constant.int 8 + %int128_6482 = torch.constant.int 128 + %5353 = torch.prim.ListConstruct %int4_6479, %296, %int32_6480, %int8_6481, %int128_6482 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5354 = torch.aten.view %5194, %5353 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5354, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_6483 = torch.constant.int 32 + %int8_6484 = torch.constant.int 8 + %int128_6485 = torch.constant.int 128 + %5355 = torch.prim.ListConstruct %504, %int32_6483, %int8_6484, %int128_6485 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5356 = torch.aten.view %5354, %5355 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %5356, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_6486 = torch.constant.int 1 %int2_6487 = torch.constant.int 2 - %int32_6488 = torch.constant.int 32 - %int8_6489 = torch.constant.int 8 - %int128_6490 = torch.constant.int 128 - %5353 = torch.prim.ListConstruct %389, %int32_6486, %int2_6487, %int32_6488, %int8_6489, %int128_6490 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5354 = torch.aten.view %5352, %5353 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5354, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6491 = torch.constant.int 2097152 - %5355 = torch.prim.ListConstruct %389, %int2097152_6491 : (!torch.int, !torch.int) -> !torch.list - %5356 = torch.aten.view %5354, %5355 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5356, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_6492 = torch.constant.int -2 - %5357 = torch.aten.unsqueeze %5315, %int-2_6492 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5357, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : 
!torch.vtensor<[4,?,8,1,128],f16> - %int4_6493 = torch.constant.int 4 - %int8_6494 = torch.constant.int 8 - %int4_6495 = torch.constant.int 4 - %int128_6496 = torch.constant.int 128 - %5358 = torch.prim.ListConstruct %int4_6493, %5300, %int8_6494, %int4_6495, %int128_6496 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_6497 = torch.constant.bool false - %5359 = torch.aten.expand %5357, %5358, %false_6497 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5359, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_6498 = torch.constant.int 0 - %5360 = torch.aten.clone %5359, %int0_6498 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5360, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %5357 = torch.aten.transpose.int %5356, %int1_6486, %int2_6487 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5357, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_6488 = torch.constant.int 5 + %5358 = torch.prims.convert_element_type %5357, %int5_6488 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5358, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %5359 = torch.prim.ListConstruct %5352 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_6489 = torch.constant.bool false + %5360 = torch.aten.index_put %5346, %5359, %5358, %false_6489 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5360, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_6490 = torch.constant.int 32 + %int2_6491 = torch.constant.int 2 + %int8_6492 = torch.constant.int 8 + %int32_6493 = torch.constant.int 32 + %int128_6494 = torch.constant.int 128 + %5361 = torch.prim.ListConstruct %297, %int32_6490, %int2_6491, %int8_6492, %int32_6493, %int128_6494 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5362 = torch.aten.view %5360, %5361 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5362, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_6495 = torch.constant.int 2097152 + %5363 = torch.prim.ListConstruct %297, %int2097152_6495 : (!torch.int, !torch.int) -> !torch.list + %5364 = torch.aten.view %5362, %5363 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5364, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_6496 = torch.constant.int -2 + %5365 = torch.aten.unsqueeze %5320, %int-2_6496 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5365, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_6497 = torch.constant.int 4 + %int8_6498 = torch.constant.int 8 %int4_6499 = torch.constant.int 4 - %int32_6500 = torch.constant.int 32 - %int128_6501 = torch.constant.int 128 - %5361 = torch.prim.ListConstruct %int4_6499, %5300, 
%int32_6500, %int128_6501 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5362 = torch.aten._unsafe_view %5360, %5361 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5362, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_6502 = torch.constant.int -2 - %5363 = torch.aten.unsqueeze %5259, %int-2_6502 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5363, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_6503 = torch.constant.int 1 - %5364 = torch.aten.size.int %5253, %int1_6503 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_6504 = torch.constant.int 4 - %int8_6505 = torch.constant.int 8 - %int4_6506 = torch.constant.int 4 - %int128_6507 = torch.constant.int 128 - %5365 = torch.prim.ListConstruct %int4_6504, %5364, %int8_6505, %int4_6506, %int128_6507 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_6508 = torch.constant.bool false - %5366 = torch.aten.expand %5363, %5365, %false_6508 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5366, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_6509 = torch.constant.int 0 - %5367 = torch.aten.clone %5366, %int0_6509 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5367, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_6510 = torch.constant.int 4 - %int32_6511 = torch.constant.int 32 - %int128_6512 = torch.constant.int 128 - %5368 = torch.prim.ListConstruct %int4_6510, %5364, %int32_6511, %int128_6512 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5369 = torch.aten._unsafe_view %5367, %5368 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5369, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_6513 = torch.constant.int 1 - %int2_6514 = torch.constant.int 2 - %5370 = torch.aten.transpose.int %5287, %int1_6513, %int2_6514 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5370, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6515 = torch.constant.int 1 - %int2_6516 = torch.constant.int 2 - %5371 = torch.aten.transpose.int %5362, %int1_6515, %int2_6516 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5371, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6517 = torch.constant.int 1 - %int2_6518 = torch.constant.int 2 - %5372 = torch.aten.transpose.int %5369, %int1_6517, %int2_6518 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5372, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_6519 = torch.constant.float 0.000000e+00 - %true_6520 = torch.constant.bool true - %none_6521 = torch.constant.none - %none_6522 = torch.constant.none - %5373:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5370, 
%5371, %5372, %float0.000000e00_6519, %true_6520, %none_6521, %none_6522) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %5373#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6523 = torch.constant.int 1 - %int2_6524 = torch.constant.int 2 - %5374 = torch.aten.transpose.int %5373#0, %int1_6523, %int2_6524 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5374, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_6525 = torch.constant.int 4 - %int4096_6526 = torch.constant.int 4096 - %5375 = torch.prim.ListConstruct %int4_6525, %5272, %int4096_6526 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5376 = torch.aten.view %5374, %5375 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5376, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6527 = torch.constant.int -2 - %int-1_6528 = torch.constant.int -1 - %5377 = torch.aten.transpose.int %230, %int-2_6527, %int-1_6528 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6529 = torch.constant.int 4 - %5378 = torch.aten.mul.int %int4_6529, %5272 : !torch.int, !torch.int -> !torch.int - %int4096_6530 = torch.constant.int 4096 - %5379 = torch.prim.ListConstruct %5378, %int4096_6530 : (!torch.int, !torch.int) -> !torch.list - %5380 = torch.aten.view %5376, %5379 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5380, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5381 = torch.aten.mm %5380, %5377 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5381, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_6531 = torch.constant.int 4 + %int128_6500 = torch.constant.int 128 + %5366 = torch.prim.ListConstruct %int4_6497, %298, %int8_6498, %int4_6499, %int128_6500 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_6501 = torch.constant.bool false + %5367 = torch.aten.expand %5365, %5366, %false_6501 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5367, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_6502 = torch.constant.int 0 + %5368 = torch.aten.clone %5367, %int0_6502 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5368, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_6503 = torch.constant.int 4 + %int32_6504 = torch.constant.int 32 + %int128_6505 = torch.constant.int 128 + %5369 = torch.prim.ListConstruct %int4_6503, %298, %int32_6504, %int128_6505 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5370 = torch.aten._unsafe_view %5368, %5369 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5370, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : 
!torch.vtensor<[4,?,32,128],f16> + %int-2_6506 = torch.constant.int -2 + %5371 = torch.aten.unsqueeze %5194, %int-2_6506 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5371, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_6507 = torch.constant.int 4 + %int8_6508 = torch.constant.int 8 + %int4_6509 = torch.constant.int 4 + %int128_6510 = torch.constant.int 128 + %5372 = torch.prim.ListConstruct %int4_6507, %298, %int8_6508, %int4_6509, %int128_6510 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_6511 = torch.constant.bool false + %5373 = torch.aten.expand %5371, %5372, %false_6511 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5373, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_6512 = torch.constant.int 0 + %5374 = torch.aten.clone %5373, %int0_6512 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5374, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_6513 = torch.constant.int 4 + %int32_6514 = torch.constant.int 32 + %int128_6515 = torch.constant.int 128 + %5375 = torch.prim.ListConstruct %int4_6513, %298, %int32_6514, %int128_6515 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5376 = torch.aten._unsafe_view %5374, %5375 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5376, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_6516 = torch.constant.int 1 + %int2_6517 = torch.constant.int 2 + %5377 = torch.aten.transpose.int %5257, %int1_6516, %int2_6517 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5377, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_6518 = torch.constant.int 1 + %int2_6519 = torch.constant.int 2 + %5378 = torch.aten.transpose.int %5370, %int1_6518, %int2_6519 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5378, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_6520 = torch.constant.int 1 + %int2_6521 = torch.constant.int 2 + %5379 = torch.aten.transpose.int %5376, %int1_6520, %int2_6521 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5379, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_6522 = torch.constant.float 0.000000e+00 + %false_6523 = torch.constant.bool false + %none_6524 = torch.constant.none + %5380:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5377, %5378, %5379, %float0.000000e00_6522, %false_6523, %327, %none_6524) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %5380#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_6525 = torch.constant.int 1 + %int2_6526 = 
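// NOTE: K and V ([4,?,8,128]) are unsqueezed/expanded/reshaped to [4,?,32,128], i.e. each of
// the 8 kv-heads is replicated 4x for grouped-query attention, before Q/K/V are transposed to
// [4,32,?,128]. The "+" SDPA call passes an explicit [4,1,?,?] f16 attention mask (%327) with
// is_causal = false, where the removed call used is_causal = true and no mask.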
torch.constant.int 2 + %5381 = torch.aten.transpose.int %5380#0, %int1_6525, %int2_6526 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5381, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_6527 = torch.constant.int 4 + %int4096_6528 = torch.constant.int 4096 + %5382 = torch.prim.ListConstruct %int4_6527, %298, %int4096_6528 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5383 = torch.aten.view %5381, %5382 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5383, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_6529 = torch.constant.int -2 + %int-1_6530 = torch.constant.int -1 + %5384 = torch.aten.transpose.int %168, %int-2_6529, %int-1_6530 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6531 = torch.constant.int 5 + %5385 = torch.prims.convert_element_type %5384, %int5_6531 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> %int4096_6532 = torch.constant.int 4096 - %5382 = torch.prim.ListConstruct %int4_6531, %5272, %int4096_6532 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5383 = torch.aten.view %5381, %5382 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5383, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_6533 = torch.constant.int 1 - %5384 = torch.aten.add.Tensor %5222, %5383, %int1_6533 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5384, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_6534 = torch.constant.int 6 - %5385 = torch.prims.convert_element_type %5384, %int6_6534 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5385, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_6535 = torch.constant.int 2 - %5386 = torch.aten.pow.Tensor_Scalar %5385, %int2_6535 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5386, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_6536 = torch.constant.int -1 - %5387 = torch.prim.ListConstruct %int-1_6536 : (!torch.int) -> !torch.list - %true_6537 = torch.constant.bool true - %none_6538 = torch.constant.none - %5388 = torch.aten.mean.dim %5386, %5387, %true_6537, %none_6538 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5388, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_6539 = torch.constant.float 9.9999997473787516E-6 - %int1_6540 = torch.constant.int 1 - %5389 = torch.aten.add.Scalar %5388, %float9.999990e-06_6539, %int1_6540 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5389, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5390 = torch.aten.rsqrt %5389 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5390, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5391 = torch.aten.mul.Tensor %5385, %5390 
: !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5391, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6541 = torch.constant.int 5 - %5392 = torch.prims.convert_element_type %5391, %int5_6541 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5392, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %5393 = torch.aten.mul.Tensor %231, %5392 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5393, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6542 = torch.constant.int 5 - %5394 = torch.prims.convert_element_type %5393, %int5_6542 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5394, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6543 = torch.constant.int -2 - %int-1_6544 = torch.constant.int -1 - %5395 = torch.aten.transpose.int %232, %int-2_6543, %int-1_6544 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_6545 = torch.constant.int 4 - %5396 = torch.aten.mul.int %int4_6545, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6546 = torch.constant.int 4096 - %5397 = torch.prim.ListConstruct %5396, %int4096_6546 : (!torch.int, !torch.int) -> !torch.list - %5398 = torch.aten.view %5394, %5397 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5398, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5399 = torch.aten.mm %5398, %5395 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5399, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_6547 = torch.constant.int 4 - %int14336_6548 = torch.constant.int 14336 - %5400 = torch.prim.ListConstruct %int4_6547, %306, %int14336_6548 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5401 = torch.aten.view %5399, %5400 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5401, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %5402 = torch.aten.silu %5401 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5402, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_6549 = torch.constant.int -2 - %int-1_6550 = torch.constant.int -1 - %5403 = torch.aten.transpose.int %233, %int-2_6549, %int-1_6550 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_6551 = torch.constant.int 4 - %5404 = torch.aten.mul.int %int4_6551, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6552 = torch.constant.int 4096 - %5405 = torch.prim.ListConstruct %5404, %int4096_6552 : (!torch.int, !torch.int) -> !torch.list - %5406 = torch.aten.view %5394, %5405 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5406, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5407 = torch.aten.mm %5406, %5403 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> 
!torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5407, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_6553 = torch.constant.int 4 - %int14336_6554 = torch.constant.int 14336 - %5408 = torch.prim.ListConstruct %int4_6553, %306, %int14336_6554 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5409 = torch.aten.view %5407, %5408 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5409, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %5410 = torch.aten.mul.Tensor %5402, %5409 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5410, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_6555 = torch.constant.int -2 - %int-1_6556 = torch.constant.int -1 - %5411 = torch.aten.transpose.int %234, %int-2_6555, %int-1_6556 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_6557 = torch.constant.int 1 - %5412 = torch.aten.size.int %5401, %int1_6557 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_6558 = torch.constant.int 4 - %5413 = torch.aten.mul.int %int4_6558, %5412 : !torch.int, !torch.int -> !torch.int - %int14336_6559 = torch.constant.int 14336 - %5414 = torch.prim.ListConstruct %5413, %int14336_6559 : (!torch.int, !torch.int) -> !torch.list - %5415 = torch.aten.view %5410, %5414 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5415, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %5416 = torch.aten.mm %5415, %5411 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5416, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_6560 = torch.constant.int 4 - %int4096_6561 = torch.constant.int 4096 - %5417 = torch.prim.ListConstruct %int4_6560, %5412, %int4096_6561 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5418 = torch.aten.view %5416, %5417 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5418, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_6562 = torch.constant.int 1 - %5419 = torch.aten.add.Tensor %5384, %5418, %int1_6562 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5419, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_6563 = torch.constant.int 6 - %5420 = torch.prims.convert_element_type %5419, %int6_6563 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5420, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_6564 = torch.constant.int 2 - %5421 = torch.aten.pow.Tensor_Scalar %5420, %int2_6564 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5421, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_6565 = torch.constant.int -1 - %5422 = torch.prim.ListConstruct %int-1_6565 : (!torch.int) -> !torch.list - %true_6566 = torch.constant.bool true - %none_6567 = torch.constant.none - %5423 = torch.aten.mean.dim %5421, %5422, 
%true_6566, %none_6567 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5423, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_6568 = torch.constant.float 9.9999997473787516E-6 - %int1_6569 = torch.constant.int 1 - %5424 = torch.aten.add.Scalar %5423, %float9.999990e-06_6568, %int1_6569 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5424, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5425 = torch.aten.rsqrt %5424 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5425, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5426 = torch.aten.mul.Tensor %5420, %5425 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5426, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6570 = torch.constant.int 5 - %5427 = torch.prims.convert_element_type %5426, %int5_6570 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5427, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %5428 = torch.aten.mul.Tensor %235, %5427 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5428, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %5386 = torch.prim.ListConstruct %342, %int4096_6532 : (!torch.int, !torch.int) -> !torch.list + %5387 = torch.aten.view %5383, %5386 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5387, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5388 = torch.aten.mm %5387, %5385 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5388, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_6533 = torch.constant.int 4 + %int4096_6534 = torch.constant.int 4096 + %5389 = torch.prim.ListConstruct %int4_6533, %298, %int4096_6534 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5390 = torch.aten.view %5388, %5389 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5390, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_6535 = torch.constant.int 1 + %5391 = torch.aten.add.Tensor %5157, %5390, %int1_6535 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5391, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_6536 = torch.constant.int 6 + %5392 = torch.prims.convert_element_type %5391, %int6_6536 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5392, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_6537 = torch.constant.int 2 + %5393 = torch.aten.pow.Tensor_Scalar %5392, %int2_6537 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5393, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_6538 = 
torch.constant.int -1 + %5394 = torch.prim.ListConstruct %int-1_6538 : (!torch.int) -> !torch.list + %true_6539 = torch.constant.bool true + %none_6540 = torch.constant.none + %5395 = torch.aten.mean.dim %5393, %5394, %true_6539, %none_6540 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5395, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_6541 = torch.constant.float 9.9999997473787516E-6 + %int1_6542 = torch.constant.int 1 + %5396 = torch.aten.add.Scalar %5395, %float9.999990e-06_6541, %int1_6542 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5396, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %5397 = torch.aten.rsqrt %5396 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5397, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %5398 = torch.aten.mul.Tensor %5392, %5397 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5398, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_6543 = torch.constant.int 5 + %5399 = torch.prims.convert_element_type %5398, %int5_6543 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5399, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %5400 = torch.aten.mul.Tensor %169, %5399 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5400, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_6544 = torch.constant.int 5 + %5401 = torch.prims.convert_element_type %5400, %int5_6544 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5401, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_6545 = torch.constant.int -2 + %int-1_6546 = torch.constant.int -1 + %5402 = torch.aten.transpose.int %170, %int-2_6545, %int-1_6546 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6547 = torch.constant.int 5 + %5403 = torch.prims.convert_element_type %5402, %int5_6547 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_6548 = torch.constant.int 4096 + %5404 = torch.prim.ListConstruct %342, %int4096_6548 : (!torch.int, !torch.int) -> !torch.list + %5405 = torch.aten.view %5401, %5404 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5405, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5406 = torch.aten.mm %5405, %5403 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %5406, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_6549 = torch.constant.int 4 + %int14336_6550 = torch.constant.int 14336 + %5407 = torch.prim.ListConstruct %int4_6549, %298, %int14336_6550 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5408 = torch.aten.view %5406, %5407 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %5408, [%294], 
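// NOTE: %5392..%5401 read as RMSNorm computed in f32 (mean of x^2 over the last dim,
// + eps ~1e-05, rsqrt, rescale, multiply by the f32 ffn_norm weight %169), feeding the SwiGLU
// FFN: silu(x @ ffn_gate^T) * (x @ ffn_up^T) @ ffn_down^T. The "+" side also inserts no-op
// f16->f16 convert_element_type casts on the transposed weights (e.g. %5403, %5411, %5419).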
affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %5409 = torch.aten.silu %5408 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %5409, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_6551 = torch.constant.int -2 + %int-1_6552 = torch.constant.int -1 + %5410 = torch.aten.transpose.int %171, %int-2_6551, %int-1_6552 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6553 = torch.constant.int 5 + %5411 = torch.prims.convert_element_type %5410, %int5_6553 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_6554 = torch.constant.int 4096 + %5412 = torch.prim.ListConstruct %342, %int4096_6554 : (!torch.int, !torch.int) -> !torch.list + %5413 = torch.aten.view %5401, %5412 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5413, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5414 = torch.aten.mm %5413, %5411 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %5414, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_6555 = torch.constant.int 4 + %int14336_6556 = torch.constant.int 14336 + %5415 = torch.prim.ListConstruct %int4_6555, %298, %int14336_6556 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5416 = torch.aten.view %5414, %5415 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %5416, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %5417 = torch.aten.mul.Tensor %5409, %5416 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %5417, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_6557 = torch.constant.int -2 + %int-1_6558 = torch.constant.int -1 + %5418 = torch.aten.transpose.int %172, %int-2_6557, %int-1_6558 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_6559 = torch.constant.int 5 + %5419 = torch.prims.convert_element_type %5418, %int5_6559 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_6560 = torch.constant.int 14336 + %5420 = torch.prim.ListConstruct %342, %int14336_6560 : (!torch.int, !torch.int) -> !torch.list + %5421 = torch.aten.view %5417, %5420 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %5421, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %5422 = torch.aten.mm %5421, %5419 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5422, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_6561 = torch.constant.int 4 + %int4096_6562 = torch.constant.int 4096 + %5423 = torch.prim.ListConstruct %int4_6561, %298, %int4096_6562 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5424 = torch.aten.view %5422, %5423 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5424, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_6563 = torch.constant.int 1 + 
%5425 = torch.aten.add.Tensor %5391, %5424, %int1_6563 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5425, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_6564 = torch.constant.int 6 + %5426 = torch.prims.convert_element_type %5425, %int6_6564 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5426, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_6565 = torch.constant.int 2 + %5427 = torch.aten.pow.Tensor_Scalar %5426, %int2_6565 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5427, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_6566 = torch.constant.int -1 + %5428 = torch.prim.ListConstruct %int-1_6566 : (!torch.int) -> !torch.list + %true_6567 = torch.constant.bool true + %none_6568 = torch.constant.none + %5429 = torch.aten.mean.dim %5427, %5428, %true_6567, %none_6568 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5429, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_6569 = torch.constant.float 9.9999997473787516E-6 + %int1_6570 = torch.constant.int 1 + %5430 = torch.aten.add.Scalar %5429, %float9.999990e-06_6569, %int1_6570 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5430, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %5431 = torch.aten.rsqrt %5430 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5431, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %5432 = torch.aten.mul.Tensor %5426, %5431 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5432, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> %int5_6571 = torch.constant.int 5 - %5429 = torch.prims.convert_element_type %5428, %int5_6571 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5429, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6572 = torch.constant.int -2 - %int-1_6573 = torch.constant.int -1 - %5430 = torch.aten.transpose.int %236, %int-2_6572, %int-1_6573 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6574 = torch.constant.int 4 - %5431 = torch.aten.mul.int %int4_6574, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6575 = torch.constant.int 4096 - %5432 = torch.prim.ListConstruct %5431, %int4096_6575 : (!torch.int, !torch.int) -> !torch.list - %5433 = torch.aten.view %5429, %5432 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5433, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5434 = torch.aten.mm %5433, %5430 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5434, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_6576 = torch.constant.int 4 - %int4096_6577 = torch.constant.int 4096 - %5435 = torch.prim.ListConstruct 
%int4_6576, %306, %int4096_6577 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5436 = torch.aten.view %5434, %5435 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5436, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6578 = torch.constant.int -2 - %int-1_6579 = torch.constant.int -1 - %5437 = torch.aten.transpose.int %237, %int-2_6578, %int-1_6579 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6580 = torch.constant.int 4 - %5438 = torch.aten.mul.int %int4_6580, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6581 = torch.constant.int 4096 - %5439 = torch.prim.ListConstruct %5438, %int4096_6581 : (!torch.int, !torch.int) -> !torch.list - %5440 = torch.aten.view %5429, %5439 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5440, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5441 = torch.aten.mm %5440, %5437 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %5441, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_6582 = torch.constant.int 4 - %int1024_6583 = torch.constant.int 1024 - %5442 = torch.prim.ListConstruct %int4_6582, %306, %int1024_6583 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5443 = torch.aten.view %5441, %5442 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %5443, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_6584 = torch.constant.int -2 - %int-1_6585 = torch.constant.int -1 - %5444 = torch.aten.transpose.int %238, %int-2_6584, %int-1_6585 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6586 = torch.constant.int 4 - %5445 = torch.aten.mul.int %int4_6586, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6587 = torch.constant.int 4096 - %5446 = torch.prim.ListConstruct %5445, %int4096_6587 : (!torch.int, !torch.int) -> !torch.list - %5447 = torch.aten.view %5429, %5446 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5447, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5448 = torch.aten.mm %5447, %5444 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %5448, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_6588 = torch.constant.int 4 - %int1024_6589 = torch.constant.int 1024 - %5449 = torch.prim.ListConstruct %int4_6588, %306, %int1024_6589 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5450 = torch.aten.view %5448, %5449 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %5450, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_6590 = torch.constant.int 4 - %int32_6591 = torch.constant.int 32 - %int128_6592 = torch.constant.int 128 - %5451 = torch.prim.ListConstruct %int4_6590, %306, %int32_6591, %int128_6592 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5452 = torch.aten.view %5436, %5451 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5452, [%292], 
affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_6593 = torch.constant.int 4 - %int8_6594 = torch.constant.int 8 - %int128_6595 = torch.constant.int 128 - %5453 = torch.prim.ListConstruct %int4_6593, %306, %int8_6594, %int128_6595 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5454 = torch.aten.view %5443, %5453 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5454, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_6596 = torch.constant.int 4 - %int8_6597 = torch.constant.int 8 - %int128_6598 = torch.constant.int 128 - %5455 = torch.prim.ListConstruct %int4_6596, %306, %int8_6597, %int128_6598 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5456 = torch.aten.view %5450, %5455 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5456, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_6599 = torch.constant.int 131072 - %none_6600 = torch.constant.none + %5433 = torch.prims.convert_element_type %5432, %int5_6571 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5433, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %5434 = torch.aten.mul.Tensor %173, %5433 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5434, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_6572 = torch.constant.int 5 + %5435 = torch.prims.convert_element_type %5434, %int5_6572 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5435, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_6573 = torch.constant.int -2 + %int-1_6574 = torch.constant.int -1 + %5436 = torch.aten.transpose.int %174, %int-2_6573, %int-1_6574 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6575 = torch.constant.int 5 + %5437 = torch.prims.convert_element_type %5436, %int5_6575 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_6576 = torch.constant.int 4096 + %5438 = torch.prim.ListConstruct %342, %int4096_6576 : (!torch.int, !torch.int) -> !torch.list + %5439 = torch.aten.view %5435, %5438 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5439, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5440 = torch.aten.mm %5439, %5437 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5440, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_6577 = torch.constant.int 4 + %int4096_6578 = torch.constant.int 4096 + %5441 = torch.prim.ListConstruct %int4_6577, %298, %int4096_6578 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5442 = torch.aten.view %5440, %5441 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5442, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_6579 = torch.constant.int -2 + %int-1_6580 = torch.constant.int -1 + %5443 = torch.aten.transpose.int %175, %int-2_6579, 
%int-1_6580 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6581 = torch.constant.int 5 + %5444 = torch.prims.convert_element_type %5443, %int5_6581 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_6582 = torch.constant.int 4096 + %5445 = torch.prim.ListConstruct %342, %int4096_6582 : (!torch.int, !torch.int) -> !torch.list + %5446 = torch.aten.view %5435, %5445 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5446, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5447 = torch.aten.mm %5446, %5444 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %5447, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_6583 = torch.constant.int 4 + %int1024_6584 = torch.constant.int 1024 + %5448 = torch.prim.ListConstruct %int4_6583, %298, %int1024_6584 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5449 = torch.aten.view %5447, %5448 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %5449, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_6585 = torch.constant.int -2 + %int-1_6586 = torch.constant.int -1 + %5450 = torch.aten.transpose.int %176, %int-2_6585, %int-1_6586 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6587 = torch.constant.int 5 + %5451 = torch.prims.convert_element_type %5450, %int5_6587 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_6588 = torch.constant.int 4096 + %5452 = torch.prim.ListConstruct %342, %int4096_6588 : (!torch.int, !torch.int) -> !torch.list + %5453 = torch.aten.view %5435, %5452 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5453, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5454 = torch.aten.mm %5453, %5451 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %5454, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_6589 = torch.constant.int 4 + %int1024_6590 = torch.constant.int 1024 + %5455 = torch.prim.ListConstruct %int4_6589, %298, %int1024_6590 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5456 = torch.aten.view %5454, %5455 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %5456, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_6591 = torch.constant.int 4 + %int32_6592 = torch.constant.int 32 + %int128_6593 = torch.constant.int 128 + %5457 = torch.prim.ListConstruct %int4_6591, %298, %int32_6592, %int128_6593 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5458 = torch.aten.view %5442, %5457 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5458, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_6594 = torch.constant.int 4 + %int8_6595 = torch.constant.int 8 + %int128_6596 = torch.constant.int 128 + %5459 = torch.prim.ListConstruct %int4_6594, %298, %int8_6595, %int128_6596 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + 
%5460 = torch.aten.view %5449, %5459 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5460, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_6597 = torch.constant.int 4 + %int8_6598 = torch.constant.int 8 + %int128_6599 = torch.constant.int 128 + %5461 = torch.prim.ListConstruct %int4_6597, %298, %int8_6598, %int128_6599 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5462 = torch.aten.view %5456, %5461 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5462, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_6600 = torch.constant.int 131072 %none_6601 = torch.constant.none - %cpu_6602 = torch.constant.device "cpu" - %false_6603 = torch.constant.bool false - %5457 = torch.aten.arange %int131072_6599, %none_6600, %none_6601, %cpu_6602, %false_6603 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_6604 = torch.constant.int 0 - %int128_6605 = torch.constant.int 128 - %none_6606 = torch.constant.none - %none_6607 = torch.constant.none - %cpu_6608 = torch.constant.device "cpu" - %false_6609 = torch.constant.bool false - %5458 = torch.aten.arange.start %int0_6604, %int128_6605, %none_6606, %none_6607, %cpu_6608, %false_6609 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_6610 = torch.constant.int 2 - %5459 = torch.aten.floor_divide.Scalar %5458, %int2_6610 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_6611 = torch.constant.int 6 - %5460 = torch.prims.convert_element_type %5459, %int6_6611 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_6612 = torch.constant.int 128 - %5461 = torch.aten.div.Scalar %5460, %int128_6612 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_6613 = torch.constant.float 2.000000e+00 - %5462 = torch.aten.mul.Scalar %5461, %float2.000000e00_6613 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %none_6602 = torch.constant.none + %cpu_6603 = torch.constant.device "cpu" + %false_6604 = torch.constant.bool false + %5463 = torch.aten.arange %int131072_6600, %none_6601, %none_6602, %cpu_6603, %false_6604 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_6605 = torch.constant.int 0 + %int128_6606 = torch.constant.int 128 + %int2_6607 = torch.constant.int 2 + %int4_6608 = torch.constant.int 4 + %none_6609 = torch.constant.none + %cpu_6610 = torch.constant.device "cpu" + %false_6611 = torch.constant.bool false + %5464 = torch.aten.arange.start_step %int0_6605, %int128_6606, %int2_6607, %int4_6608, %none_6609, %cpu_6610, %false_6611 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_6612 = torch.constant.int 6 + %5465 = torch.prims.convert_element_type %5464, %int6_6612 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_6613 = torch.constant.int 128 + %5466 = torch.aten.div.Scalar %5465, %int128_6613 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %float5.000000e05_6614 = torch.constant.float 5.000000e+05 - %5463 = torch.aten.pow.Scalar %float5.000000e05_6614, %5462 : !torch.float, !torch.vtensor<[128],f32> -> 
!torch.vtensor<[128],f32> - %5464 = torch.aten.reciprocal %5463 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> + %5467 = torch.aten.pow.Scalar %float5.000000e05_6614, %5466 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5468 = torch.aten.reciprocal %5467 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> %float1.000000e00_6615 = torch.constant.float 1.000000e+00 - %5465 = torch.aten.mul.Scalar %5464, %float1.000000e00_6615 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_6616 = torch.constant.int 1 - %5466 = torch.aten.unsqueeze %5457, %int1_6616 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_6617 = torch.constant.int 0 - %5467 = torch.aten.unsqueeze %5465, %int0_6617 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %5468 = torch.aten.mul.Tensor %5466, %5467 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_6618 = torch.constant.int 1 - %5469 = torch.aten.size.int %5436, %int1_6618 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_6619 = torch.constant.int 0 - %5470 = torch.aten.add.int %int0_6619, %5469 : !torch.int, !torch.int -> !torch.int - %int0_6620 = torch.constant.int 0 - %int0_6621 = torch.constant.int 0 - %int1_6622 = torch.constant.int 1 - %5471 = torch.aten.slice.Tensor %5468, %int0_6620, %int0_6621, %5470, %int1_6622 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5471, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %5469 = torch.aten.mul.Scalar %5468, %float1.000000e00_6615 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %5470 = torch.aten.reciprocal %5469 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_6616 = torch.constant.float 6.2831853071795862 + %5471 = torch.aten.mul.Scalar %5470, %float6.283190e00_6616 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_6617 = torch.constant.float 8.192000e+03 + %5472 = torch.aten.gt.Scalar %5471, %float8.192000e03_6617 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_6618 = torch.constant.int 8 + %5473 = torch.aten.div.Scalar %5469, %int8_6618 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5474 = torch.aten.where.self %5472, %5473, %5469 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5475 = torch.aten.reciprocal %5471 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_6619 = torch.constant.int 8192 + %5476 = torch.aten.mul.Scalar %5475, %int8192_6619 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_6620 = torch.constant.int 1 + %int1_6621 = torch.constant.int 1 + %5477 = torch.aten.sub.Scalar %5476, %int1_6620, %int1_6621 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_6622 = torch.constant.int 3 + %5478 = torch.aten.div.Scalar %5477, %int3_6622 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_6623 = torch.constant.int 1 - %int0_6624 = torch.constant.int 0 - %int9223372036854775807_6625 = torch.constant.int 9223372036854775807 + %int1_6624 = torch.constant.int 1 + %5479 = torch.aten.rsub.Scalar %5478, %int1_6623, %int1_6624 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %5480 = 
torch.aten.mul.Tensor %5479, %5474 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_6625 = torch.constant.int 8 + %5481 = torch.aten.div.Scalar %5480, %int8_6625 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5482 = torch.aten.mul.Tensor %5478, %5474 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> %int1_6626 = torch.constant.int 1 - %5472 = torch.aten.slice.Tensor %5471, %int1_6623, %int0_6624, %int9223372036854775807_6625, %int1_6626 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5472, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_6627 = torch.constant.int 1 - %int0_6628 = torch.constant.int 0 - %int9223372036854775807_6629 = torch.constant.int 9223372036854775807 - %int1_6630 = torch.constant.int 1 - %5473 = torch.aten.slice.Tensor %5472, %int1_6627, %int0_6628, %int9223372036854775807_6629, %int1_6630 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5473, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_6631 = torch.constant.int 0 - %5474 = torch.aten.unsqueeze %5473, %int0_6631 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5474, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_6632 = torch.constant.int 1 + %5483 = torch.aten.add.Tensor %5481, %5482, %int1_6626 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_6627 = torch.constant.float 2.048000e+03 + %5484 = torch.aten.lt.Scalar %5471, %float2.048000e03_6627 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5485 = torch.aten.bitwise_not %5484 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_6628 = torch.constant.float 8.192000e+03 + %5486 = torch.aten.gt.Scalar %5471, %float8.192000e03_6628 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5487 = torch.aten.bitwise_not %5486 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5488 = torch.aten.mul.Tensor %5485, %5487 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5489 = torch.aten.where.self %5488, %5483, %5474 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5490 = torch.prim.ListConstruct %5489, %5489 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_6629 = torch.constant.int -1 + %5491 = torch.aten.cat %5490, %int-1_6629 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_6630 = torch.constant.int 6 + %5492 = torch.prims.convert_element_type %5491, %int6_6630 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_6631 = torch.constant.int 1 + %5493 = torch.aten.unsqueeze %5463, %int1_6631 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_6632 = torch.constant.int 6 + %5494 = torch.prims.convert_element_type %5493, %int6_6632 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> %int0_6633 = torch.constant.int 0 - %int9223372036854775807_6634 = torch.constant.int 9223372036854775807 - %int1_6635 = torch.constant.int 1 - %5475 = torch.aten.slice.Tensor %5474, %int1_6632, %int0_6633, 
%int9223372036854775807_6634, %int1_6635 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5475, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_6636 = torch.constant.int 2 + %5495 = torch.aten.unsqueeze %5492, %int0_6633 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_6634 = torch.constant.int 6 + %5496 = torch.prims.convert_element_type %5495, %int6_6634 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %5497 = torch.aten.mul.Tensor %5494, %5496 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %5498 = torch.aten.cos %5497 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_6635 = torch.constant.int 5 + %5499 = torch.prims.convert_element_type %5498, %int5_6635 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %5500 = torch.aten.sin %5497 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_6636 = torch.constant.int 5 + %5501 = torch.prims.convert_element_type %5500, %int5_6636 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> %int0_6637 = torch.constant.int 0 - %int9223372036854775807_6638 = torch.constant.int 9223372036854775807 + %int0_6638 = torch.constant.int 0 %int1_6639 = torch.constant.int 1 - %5476 = torch.aten.slice.Tensor %5475, %int2_6636, %int0_6637, %int9223372036854775807_6638, %int1_6639 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5476, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_6640 = torch.constant.int 4 - %int1_6641 = torch.constant.int 1 - %int1_6642 = torch.constant.int 1 - %5477 = torch.prim.ListConstruct %int4_6640, %int1_6641, %int1_6642 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5478 = torch.aten.repeat %5476, %5477 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %5478, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_6643 = torch.constant.int 6 - %5479 = torch.prims.convert_element_type %5452, %int6_6643 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %5479, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %5480 = torch_c.to_builtin_tensor %5479 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %5481 = torch_c.to_builtin_tensor %5478 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %5482 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%5480, %5481) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %5483 = torch_c.from_builtin_tensor %5482 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %5483, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_6644 = torch.constant.int 5 - %5484 = torch.prims.convert_element_type %5483, %int5_6644 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5484, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_6645 = torch.constant.int 131072 - %none_6646 = torch.constant.none - %none_6647 
= torch.constant.none - %cpu_6648 = torch.constant.device "cpu" - %false_6649 = torch.constant.bool false - %5485 = torch.aten.arange %int131072_6645, %none_6646, %none_6647, %cpu_6648, %false_6649 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_6650 = torch.constant.int 0 - %int128_6651 = torch.constant.int 128 - %none_6652 = torch.constant.none - %none_6653 = torch.constant.none - %cpu_6654 = torch.constant.device "cpu" - %false_6655 = torch.constant.bool false - %5486 = torch.aten.arange.start %int0_6650, %int128_6651, %none_6652, %none_6653, %cpu_6654, %false_6655 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> + %5502 = torch.aten.slice.Tensor %5499, %int0_6637, %int0_6638, %298, %int1_6639 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5502, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_6640 = torch.constant.int 1 + %int0_6641 = torch.constant.int 0 + %int9223372036854775807_6642 = torch.constant.int 9223372036854775807 + %int1_6643 = torch.constant.int 1 + %5503 = torch.aten.slice.Tensor %5502, %int1_6640, %int0_6641, %int9223372036854775807_6642, %int1_6643 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5503, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_6644 = torch.constant.int 0 + %int0_6645 = torch.constant.int 0 + %int1_6646 = torch.constant.int 1 + %5504 = torch.aten.slice.Tensor %5501, %int0_6644, %int0_6645, %298, %int1_6646 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5504, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_6647 = torch.constant.int 1 + %int0_6648 = torch.constant.int 0 + %int9223372036854775807_6649 = torch.constant.int 9223372036854775807 + %int1_6650 = torch.constant.int 1 + %5505 = torch.aten.slice.Tensor %5504, %int1_6647, %int0_6648, %int9223372036854775807_6649, %int1_6650 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5505, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_6651 = torch.constant.int 0 + %5506 = torch.aten.unsqueeze %5503, %int0_6651 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5506, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_6652 = torch.constant.int 1 + %int0_6653 = torch.constant.int 0 + %int9223372036854775807_6654 = torch.constant.int 9223372036854775807 + %int1_6655 = torch.constant.int 1 + %5507 = torch.aten.slice.Tensor %5506, %int1_6652, %int0_6653, %int9223372036854775807_6654, %int1_6655 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5507, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int2_6656 = torch.constant.int 2 - %5487 = torch.aten.floor_divide.Scalar %5486, %int2_6656 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_6657 = torch.constant.int 6 - %5488 = torch.prims.convert_element_type %5487, %int6_6657 : !torch.vtensor<[128],si64>, 
!torch.int -> !torch.vtensor<[128],f32> - %int128_6658 = torch.constant.int 128 - %5489 = torch.aten.div.Scalar %5488, %int128_6658 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_6659 = torch.constant.float 2.000000e+00 - %5490 = torch.aten.mul.Scalar %5489, %float2.000000e00_6659 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_6660 = torch.constant.float 5.000000e+05 - %5491 = torch.aten.pow.Scalar %float5.000000e05_6660, %5490 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %5492 = torch.aten.reciprocal %5491 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_6661 = torch.constant.float 1.000000e+00 - %5493 = torch.aten.mul.Scalar %5492, %float1.000000e00_6661 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %5508 = torch.aten.unsqueeze %5507, %int2_6656 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5508, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_6657 = torch.constant.int 3 + %int0_6658 = torch.constant.int 0 + %int9223372036854775807_6659 = torch.constant.int 9223372036854775807 + %int1_6660 = torch.constant.int 1 + %5509 = torch.aten.slice.Tensor %5508, %int3_6657, %int0_6658, %int9223372036854775807_6659, %int1_6660 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5509, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_6661 = torch.constant.int 4 %int1_6662 = torch.constant.int 1 - %5494 = torch.aten.unsqueeze %5485, %int1_6662 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_6663 = torch.constant.int 0 - %5495 = torch.aten.unsqueeze %5493, %int0_6663 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %5496 = torch.aten.mul.Tensor %5494, %5495 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %int1_6663 = torch.constant.int 1 %int1_6664 = torch.constant.int 1 - %5497 = torch.aten.size.int %5443, %int1_6664 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int + %5510 = torch.prim.ListConstruct %int4_6661, %int1_6662, %int1_6663, %int1_6664 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5511 = torch.aten.repeat %5509, %5510 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5511, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> %int0_6665 = torch.constant.int 0 - %5498 = torch.aten.add.int %int0_6665, %5497 : !torch.int, !torch.int -> !torch.int - %int0_6666 = torch.constant.int 0 + %5512 = torch.aten.unsqueeze %5505, %int0_6665 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5512, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_6666 = torch.constant.int 1 %int0_6667 = torch.constant.int 0 - %int1_6668 = torch.constant.int 1 - %5499 = torch.aten.slice.Tensor %5496, %int0_6666, %int0_6667, %5498, %int1_6668 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5499, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %int9223372036854775807_6668 
= torch.constant.int 9223372036854775807 %int1_6669 = torch.constant.int 1 - %int0_6670 = torch.constant.int 0 - %int9223372036854775807_6671 = torch.constant.int 9223372036854775807 - %int1_6672 = torch.constant.int 1 - %5500 = torch.aten.slice.Tensor %5499, %int1_6669, %int0_6670, %int9223372036854775807_6671, %int1_6672 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5500, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_6673 = torch.constant.int 1 - %int0_6674 = torch.constant.int 0 - %int9223372036854775807_6675 = torch.constant.int 9223372036854775807 + %5513 = torch.aten.slice.Tensor %5512, %int1_6666, %int0_6667, %int9223372036854775807_6668, %int1_6669 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5513, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_6670 = torch.constant.int 2 + %5514 = torch.aten.unsqueeze %5513, %int2_6670 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5514, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_6671 = torch.constant.int 3 + %int0_6672 = torch.constant.int 0 + %int9223372036854775807_6673 = torch.constant.int 9223372036854775807 + %int1_6674 = torch.constant.int 1 + %5515 = torch.aten.slice.Tensor %5514, %int3_6671, %int0_6672, %int9223372036854775807_6673, %int1_6674 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5515, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_6675 = torch.constant.int 4 %int1_6676 = torch.constant.int 1 - %5501 = torch.aten.slice.Tensor %5500, %int1_6673, %int0_6674, %int9223372036854775807_6675, %int1_6676 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5501, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_6677 = torch.constant.int 0 - %5502 = torch.aten.unsqueeze %5501, %int0_6677 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5502, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> + %int1_6677 = torch.constant.int 1 %int1_6678 = torch.constant.int 1 - %int0_6679 = torch.constant.int 0 - %int9223372036854775807_6680 = torch.constant.int 9223372036854775807 - %int1_6681 = torch.constant.int 1 - %5503 = torch.aten.slice.Tensor %5502, %int1_6678, %int0_6679, %int9223372036854775807_6680, %int1_6681 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5503, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_6682 = torch.constant.int 2 - %int0_6683 = torch.constant.int 0 - %int9223372036854775807_6684 = torch.constant.int 9223372036854775807 - %int1_6685 = torch.constant.int 1 - %5504 = torch.aten.slice.Tensor %5503, %int2_6682, %int0_6683, %int9223372036854775807_6684, %int1_6685 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5504, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : 
!torch.vtensor<[1,?,128],f32> - %int4_6686 = torch.constant.int 4 - %int1_6687 = torch.constant.int 1 + %5516 = torch.prim.ListConstruct %int4_6675, %int1_6676, %int1_6677, %int1_6678 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5517 = torch.aten.repeat %5515, %5516 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5517, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %5518 = torch.aten.mul.Tensor %5458, %5511 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5518, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_6679 = torch.constant.int 3 + %int0_6680 = torch.constant.int 0 + %int64_6681 = torch.constant.int 64 + %int1_6682 = torch.constant.int 1 + %5519 = torch.aten.slice.Tensor %5458, %int3_6679, %int0_6680, %int64_6681, %int1_6682 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %5519, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_6683 = torch.constant.int 3 + %int64_6684 = torch.constant.int 64 + %int9223372036854775807_6685 = torch.constant.int 9223372036854775807 + %int1_6686 = torch.constant.int 1 + %5520 = torch.aten.slice.Tensor %5458, %int3_6683, %int64_6684, %int9223372036854775807_6685, %int1_6686 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %5520, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %5521 = torch.aten.neg %5520 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %5521, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %5522 = torch.prim.ListConstruct %5521, %5519 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_6687 = torch.constant.int -1 + %5523 = torch.aten.cat %5522, %int-1_6687 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5523, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %5524 = torch.aten.mul.Tensor %5523, %5517 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5524, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int1_6688 = torch.constant.int 1 - %5505 = torch.prim.ListConstruct %int4_6686, %int1_6687, %int1_6688 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5506 = torch.aten.repeat %5504, %5505 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %5506, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_6689 = torch.constant.int 6 - %5507 = torch.prims.convert_element_type %5454, %int6_6689 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %5507, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %5508 = torch_c.to_builtin_tensor %5507 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %5509 = torch_c.to_builtin_tensor %5506 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %5510 = 
util.call @sharktank_rotary_embedding_4_D_8_128_f32(%5508, %5509) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %5511 = torch_c.from_builtin_tensor %5510 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %5511, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_6690 = torch.constant.int 5 - %5512 = torch.prims.convert_element_type %5511, %int5_6690 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5512, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_6691 = torch.constant.int 64 - %5513 = torch.aten.mul.Scalar %arg2, %int64_6691 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5513, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int52 = torch.constant.int 52 - %int1_6692 = torch.constant.int 1 - %5514 = torch.aten.add.Scalar %5513, %int52, %int1_6692 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5514, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_6693 = torch.constant.int 4 - %int32_6694 = torch.constant.int 32 - %int8_6695 = torch.constant.int 8 - %int128_6696 = torch.constant.int 128 - %5515 = torch.prim.ListConstruct %int4_6693, %398, %int32_6694, %int8_6695, %int128_6696 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5516 = torch.aten.view %5512, %5515 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5516, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %5525 = torch.aten.add.Tensor %5518, %5524, %int1_6688 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5525, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_6689 = torch.constant.int 131072 + %none_6690 = torch.constant.none + %none_6691 = torch.constant.none + %cpu_6692 = torch.constant.device "cpu" + %false_6693 = torch.constant.bool false + %5526 = torch.aten.arange %int131072_6689, %none_6690, %none_6691, %cpu_6692, %false_6693 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_6694 = torch.constant.int 0 + %int128_6695 = torch.constant.int 128 + %int2_6696 = torch.constant.int 2 %int4_6697 = torch.constant.int 4 - %5517 = torch.aten.mul.int %int4_6697, %398 : !torch.int, !torch.int -> !torch.int - %int32_6698 = torch.constant.int 32 - %int8_6699 = torch.constant.int 8 - %int128_6700 = torch.constant.int 128 - %5518 = torch.prim.ListConstruct %5517, %int32_6698, %int8_6699, %int128_6700 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5519 = torch.aten.view %5516, %5518 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5519, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_6701 = torch.constant.int 4 - %5520 = torch.aten.mul.int %int4_6701, %398 : !torch.int, !torch.int -> !torch.int - %5521 = torch.prim.ListConstruct %5520 : (!torch.int) -> !torch.list - %5522 = torch.aten.view %5514, %5521 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape 
%5522, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_6702 = torch.constant.int 32 - %int2_6703 = torch.constant.int 2 - %int32_6704 = torch.constant.int 32 - %int8_6705 = torch.constant.int 8 - %int128_6706 = torch.constant.int 128 - %5523 = torch.prim.ListConstruct %389, %int32_6702, %int2_6703, %int32_6704, %int8_6705, %int128_6706 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5524 = torch.aten.view %5356, %5523 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5524, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6707 = torch.constant.int 32 - %5525 = torch.aten.mul.int %389, %int32_6707 : !torch.int, !torch.int -> !torch.int - %int2_6708 = torch.constant.int 2 - %5526 = torch.aten.mul.int %5525, %int2_6708 : !torch.int, !torch.int -> !torch.int - %int32_6709 = torch.constant.int 32 - %int8_6710 = torch.constant.int 8 - %int128_6711 = torch.constant.int 128 - %5527 = torch.prim.ListConstruct %5526, %int32_6709, %int8_6710, %int128_6711 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5528 = torch.aten.view %5524, %5527 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5528, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %5529 = torch.prim.ListConstruct %5522 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_6712 = torch.constant.bool false - %5530 = torch.aten.index_put %5528, %5529, %5519, %false_6712 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5530, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_6713 = torch.constant.int 32 - %int2_6714 = torch.constant.int 2 - %int32_6715 = torch.constant.int 32 - %int8_6716 = torch.constant.int 8 - %int128_6717 = torch.constant.int 128 - %5531 = torch.prim.ListConstruct %389, %int32_6713, %int2_6714, %int32_6715, %int8_6716, %int128_6717 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5532 = torch.aten.view %5530, %5531 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5532, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6718 = torch.constant.int 2097152 - %5533 = torch.prim.ListConstruct %389, %int2097152_6718 : (!torch.int, !torch.int) -> !torch.list - %5534 = torch.aten.view %5532, %5533 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5534, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_6719 = torch.constant.int 32 - %int2_6720 = torch.constant.int 2 - %int32_6721 = torch.constant.int 32 - %int8_6722 = torch.constant.int 8 - %int128_6723 = torch.constant.int 128 - %5535 = torch.prim.ListConstruct %389, %int32_6719, %int2_6720, %int32_6721, %int8_6722, %int128_6723 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5536 = torch.aten.view %5534, %5535 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5536, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : 
!torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6724 = torch.constant.int 32 - %int8_6725 = torch.constant.int 8 - %int128_6726 = torch.constant.int 128 - %5537 = torch.prim.ListConstruct %5526, %int32_6724, %int8_6725, %int128_6726 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5538 = torch.aten.view %5536, %5537 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5538, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_6727 = torch.constant.int 4 - %int32_6728 = torch.constant.int 32 - %int8_6729 = torch.constant.int 8 - %int128_6730 = torch.constant.int 128 - %5539 = torch.prim.ListConstruct %int4_6727, %398, %int32_6728, %int8_6729, %int128_6730 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5540 = torch.aten.view %5456, %5539 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5540, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_6731 = torch.constant.int 4 - %5541 = torch.aten.mul.int %int4_6731, %398 : !torch.int, !torch.int -> !torch.int - %int32_6732 = torch.constant.int 32 - %int8_6733 = torch.constant.int 8 - %int128_6734 = torch.constant.int 128 - %5542 = torch.prim.ListConstruct %5541, %int32_6732, %int8_6733, %int128_6734 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5543 = torch.aten.view %5540, %5542 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5543, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %none_6698 = torch.constant.none + %cpu_6699 = torch.constant.device "cpu" + %false_6700 = torch.constant.bool false + %5527 = torch.aten.arange.start_step %int0_6694, %int128_6695, %int2_6696, %int4_6697, %none_6698, %cpu_6699, %false_6700 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_6701 = torch.constant.int 6 + %5528 = torch.prims.convert_element_type %5527, %int6_6701 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_6702 = torch.constant.int 128 + %5529 = torch.aten.div.Scalar %5528, %int128_6702 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_6703 = torch.constant.float 5.000000e+05 + %5530 = torch.aten.pow.Scalar %float5.000000e05_6703, %5529 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5531 = torch.aten.reciprocal %5530 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_6704 = torch.constant.float 1.000000e+00 + %5532 = torch.aten.mul.Scalar %5531, %float1.000000e00_6704 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %5533 = torch.aten.reciprocal %5532 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_6705 = torch.constant.float 6.2831853071795862 + %5534 = torch.aten.mul.Scalar %5533, %float6.283190e00_6705 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_6706 = torch.constant.float 8.192000e+03 + %5535 = torch.aten.gt.Scalar %5534, %float8.192000e03_6706 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_6707 = torch.constant.int 8 + %5536 = torch.aten.div.Scalar %5532, %int8_6707 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5537 = 
torch.aten.where.self %5535, %5536, %5532 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5538 = torch.aten.reciprocal %5534 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_6708 = torch.constant.int 8192 + %5539 = torch.aten.mul.Scalar %5538, %int8192_6708 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_6709 = torch.constant.int 1 + %int1_6710 = torch.constant.int 1 + %5540 = torch.aten.sub.Scalar %5539, %int1_6709, %int1_6710 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_6711 = torch.constant.int 3 + %5541 = torch.aten.div.Scalar %5540, %int3_6711 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_6712 = torch.constant.int 1 + %int1_6713 = torch.constant.int 1 + %5542 = torch.aten.rsub.Scalar %5541, %int1_6712, %int1_6713 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %5543 = torch.aten.mul.Tensor %5542, %5537 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_6714 = torch.constant.int 8 + %5544 = torch.aten.div.Scalar %5543, %int8_6714 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5545 = torch.aten.mul.Tensor %5541, %5537 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_6715 = torch.constant.int 1 + %5546 = torch.aten.add.Tensor %5544, %5545, %int1_6715 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_6716 = torch.constant.float 2.048000e+03 + %5547 = torch.aten.lt.Scalar %5534, %float2.048000e03_6716 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5548 = torch.aten.bitwise_not %5547 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_6717 = torch.constant.float 8.192000e+03 + %5549 = torch.aten.gt.Scalar %5534, %float8.192000e03_6717 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5550 = torch.aten.bitwise_not %5549 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5551 = torch.aten.mul.Tensor %5548, %5550 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5552 = torch.aten.where.self %5551, %5546, %5537 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5553 = torch.prim.ListConstruct %5552, %5552 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_6718 = torch.constant.int -1 + %5554 = torch.aten.cat %5553, %int-1_6718 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_6719 = torch.constant.int 6 + %5555 = torch.prims.convert_element_type %5554, %int6_6719 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_6720 = torch.constant.int 1 + %5556 = torch.aten.unsqueeze %5526, %int1_6720 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_6721 = torch.constant.int 6 + %5557 = torch.prims.convert_element_type %5556, %int6_6721 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_6722 = torch.constant.int 0 + %5558 = torch.aten.unsqueeze %5555, %int0_6722 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_6723 = torch.constant.int 6 + %5559 = torch.prims.convert_element_type %5558, %int6_6723 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %5560 = torch.aten.mul.Tensor %5557, %5559 : 
!torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %5561 = torch.aten.cos %5560 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_6724 = torch.constant.int 5 + %5562 = torch.prims.convert_element_type %5561, %int5_6724 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %5563 = torch.aten.sin %5560 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_6725 = torch.constant.int 5 + %5564 = torch.prims.convert_element_type %5563, %int5_6725 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_6726 = torch.constant.int 0 + %int0_6727 = torch.constant.int 0 + %int1_6728 = torch.constant.int 1 + %5565 = torch.aten.slice.Tensor %5562, %int0_6726, %int0_6727, %298, %int1_6728 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5565, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_6729 = torch.constant.int 1 + %int0_6730 = torch.constant.int 0 + %int9223372036854775807_6731 = torch.constant.int 9223372036854775807 + %int1_6732 = torch.constant.int 1 + %5566 = torch.aten.slice.Tensor %5565, %int1_6729, %int0_6730, %int9223372036854775807_6731, %int1_6732 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5566, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_6733 = torch.constant.int 0 + %int0_6734 = torch.constant.int 0 %int1_6735 = torch.constant.int 1 + %5567 = torch.aten.slice.Tensor %5564, %int0_6733, %int0_6734, %298, %int1_6735 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5567, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_6736 = torch.constant.int 1 - %5544 = torch.aten.add.Scalar %5514, %int1_6735, %int1_6736 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5544, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_6737 = torch.constant.int 4 - %5545 = torch.aten.mul.int %int4_6737, %398 : !torch.int, !torch.int -> !torch.int - %5546 = torch.prim.ListConstruct %5545 : (!torch.int) -> !torch.list - %5547 = torch.aten.view %5544, %5546 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %5547, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %5548 = torch.prim.ListConstruct %5547 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_6738 = torch.constant.bool false - %5549 = torch.aten.index_put %5538, %5548, %5543, %false_6738 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5549, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_6739 = torch.constant.int 32 - %int2_6740 = torch.constant.int 2 - %int32_6741 = torch.constant.int 32 - %int8_6742 = torch.constant.int 8 - %int128_6743 = torch.constant.int 128 - %5550 = torch.prim.ListConstruct %389, %int32_6739, %int2_6740, %int32_6741, %int8_6742, %int128_6743 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5551 = torch.aten.view %5549, %5550 : 
!torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5551, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6744 = torch.constant.int 2097152 - %5552 = torch.prim.ListConstruct %389, %int2097152_6744 : (!torch.int, !torch.int) -> !torch.list - %5553 = torch.aten.view %5551, %5552 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5553, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_6745 = torch.constant.int -2 - %5554 = torch.aten.unsqueeze %5512, %int-2_6745 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5554, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_6746 = torch.constant.int 4 - %int8_6747 = torch.constant.int 8 - %int4_6748 = torch.constant.int 4 - %int128_6749 = torch.constant.int 128 - %5555 = torch.prim.ListConstruct %int4_6746, %5497, %int8_6747, %int4_6748, %int128_6749 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_6750 = torch.constant.bool false - %5556 = torch.aten.expand %5554, %5555, %false_6750 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5556, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_6751 = torch.constant.int 0 - %5557 = torch.aten.clone %5556, %int0_6751 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5557, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_6752 = torch.constant.int 4 - %int32_6753 = torch.constant.int 32 - %int128_6754 = torch.constant.int 128 - %5558 = torch.prim.ListConstruct %int4_6752, %5497, %int32_6753, %int128_6754 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5559 = torch.aten._unsafe_view %5557, %5558 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5559, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_6755 = torch.constant.int -2 - %5560 = torch.aten.unsqueeze %5456, %int-2_6755 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5560, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_6756 = torch.constant.int 1 - %5561 = torch.aten.size.int %5450, %int1_6756 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_6757 = torch.constant.int 4 - %int8_6758 = torch.constant.int 8 - %int4_6759 = torch.constant.int 4 - %int128_6760 = torch.constant.int 128 - %5562 = torch.prim.ListConstruct %int4_6757, %5561, %int8_6758, %int4_6759, %int128_6760 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_6761 = torch.constant.bool false - %5563 = torch.aten.expand %5560, %5562, %false_6761 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5563, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_6762 = torch.constant.int 0 - %5564 = torch.aten.clone %5563, %int0_6762 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> 
!torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5564, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_6763 = torch.constant.int 4 - %int32_6764 = torch.constant.int 32 - %int128_6765 = torch.constant.int 128 - %5565 = torch.prim.ListConstruct %int4_6763, %5561, %int32_6764, %int128_6765 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5566 = torch.aten._unsafe_view %5564, %5565 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5566, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int0_6737 = torch.constant.int 0 + %int9223372036854775807_6738 = torch.constant.int 9223372036854775807 + %int1_6739 = torch.constant.int 1 + %5568 = torch.aten.slice.Tensor %5567, %int1_6736, %int0_6737, %int9223372036854775807_6738, %int1_6739 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5568, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_6740 = torch.constant.int 0 + %5569 = torch.aten.unsqueeze %5566, %int0_6740 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5569, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_6741 = torch.constant.int 1 + %int0_6742 = torch.constant.int 0 + %int9223372036854775807_6743 = torch.constant.int 9223372036854775807 + %int1_6744 = torch.constant.int 1 + %5570 = torch.aten.slice.Tensor %5569, %int1_6741, %int0_6742, %int9223372036854775807_6743, %int1_6744 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5570, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_6745 = torch.constant.int 2 + %5571 = torch.aten.unsqueeze %5570, %int2_6745 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5571, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_6746 = torch.constant.int 3 + %int0_6747 = torch.constant.int 0 + %int9223372036854775807_6748 = torch.constant.int 9223372036854775807 + %int1_6749 = torch.constant.int 1 + %5572 = torch.aten.slice.Tensor %5571, %int3_6746, %int0_6747, %int9223372036854775807_6748, %int1_6749 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5572, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_6750 = torch.constant.int 4 + %int1_6751 = torch.constant.int 1 + %int1_6752 = torch.constant.int 1 + %int1_6753 = torch.constant.int 1 + %5573 = torch.prim.ListConstruct %int4_6750, %int1_6751, %int1_6752, %int1_6753 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5574 = torch.aten.repeat %5572, %5573 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5574, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_6754 = torch.constant.int 0 + %5575 = torch.aten.unsqueeze %5568, %int0_6754 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5575, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + 
%int1_6755 = torch.constant.int 1 + %int0_6756 = torch.constant.int 0 + %int9223372036854775807_6757 = torch.constant.int 9223372036854775807 + %int1_6758 = torch.constant.int 1 + %5576 = torch.aten.slice.Tensor %5575, %int1_6755, %int0_6756, %int9223372036854775807_6757, %int1_6758 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5576, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_6759 = torch.constant.int 2 + %5577 = torch.aten.unsqueeze %5576, %int2_6759 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5577, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_6760 = torch.constant.int 3 + %int0_6761 = torch.constant.int 0 + %int9223372036854775807_6762 = torch.constant.int 9223372036854775807 + %int1_6763 = torch.constant.int 1 + %5578 = torch.aten.slice.Tensor %5577, %int3_6760, %int0_6761, %int9223372036854775807_6762, %int1_6763 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5578, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_6764 = torch.constant.int 4 + %int1_6765 = torch.constant.int 1 %int1_6766 = torch.constant.int 1 - %int2_6767 = torch.constant.int 2 - %5567 = torch.aten.transpose.int %5484, %int1_6766, %int2_6767 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5567, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6768 = torch.constant.int 1 - %int2_6769 = torch.constant.int 2 - %5568 = torch.aten.transpose.int %5559, %int1_6768, %int2_6769 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5568, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6770 = torch.constant.int 1 - %int2_6771 = torch.constant.int 2 - %5569 = torch.aten.transpose.int %5566, %int1_6770, %int2_6771 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5569, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_6772 = torch.constant.float 0.000000e+00 - %true_6773 = torch.constant.bool true - %none_6774 = torch.constant.none - %none_6775 = torch.constant.none - %5570:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5567, %5568, %5569, %float0.000000e00_6772, %true_6773, %none_6774, %none_6775) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %5570#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6776 = torch.constant.int 1 - %int2_6777 = torch.constant.int 2 - %5571 = torch.aten.transpose.int %5570#0, %int1_6776, %int2_6777 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5571, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_6778 = torch.constant.int 4 - %int4096_6779 = 
torch.constant.int 4096 - %5572 = torch.prim.ListConstruct %int4_6778, %5469, %int4096_6779 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5573 = torch.aten.view %5571, %5572 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5573, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6780 = torch.constant.int -2 - %int-1_6781 = torch.constant.int -1 - %5574 = torch.aten.transpose.int %239, %int-2_6780, %int-1_6781 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6782 = torch.constant.int 4 - %5575 = torch.aten.mul.int %int4_6782, %5469 : !torch.int, !torch.int -> !torch.int - %int4096_6783 = torch.constant.int 4096 - %5576 = torch.prim.ListConstruct %5575, %int4096_6783 : (!torch.int, !torch.int) -> !torch.list - %5577 = torch.aten.view %5573, %5576 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5577, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5578 = torch.aten.mm %5577, %5574 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5578, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_6784 = torch.constant.int 4 - %int4096_6785 = torch.constant.int 4096 - %5579 = torch.prim.ListConstruct %int4_6784, %5469, %int4096_6785 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5580 = torch.aten.view %5578, %5579 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5580, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_6786 = torch.constant.int 1 - %5581 = torch.aten.add.Tensor %5419, %5580, %int1_6786 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5581, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_6787 = torch.constant.int 6 - %5582 = torch.prims.convert_element_type %5581, %int6_6787 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5582, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_6788 = torch.constant.int 2 - %5583 = torch.aten.pow.Tensor_Scalar %5582, %int2_6788 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5583, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_6789 = torch.constant.int -1 - %5584 = torch.prim.ListConstruct %int-1_6789 : (!torch.int) -> !torch.list - %true_6790 = torch.constant.bool true - %none_6791 = torch.constant.none - %5585 = torch.aten.mean.dim %5583, %5584, %true_6790, %none_6791 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5585, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_6792 = torch.constant.float 9.9999997473787516E-6 - %int1_6793 = torch.constant.int 1 - %5586 = torch.aten.add.Scalar %5585, %float9.999990e-06_6792, %int1_6793 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5586, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5587 = 
torch.aten.rsqrt %5586 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5587, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5588 = torch.aten.mul.Tensor %5582, %5587 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5588, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6794 = torch.constant.int 5 - %5589 = torch.prims.convert_element_type %5588, %int5_6794 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5589, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %5590 = torch.aten.mul.Tensor %240, %5589 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5590, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6795 = torch.constant.int 5 - %5591 = torch.prims.convert_element_type %5590, %int5_6795 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5591, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6796 = torch.constant.int -2 - %int-1_6797 = torch.constant.int -1 - %5592 = torch.aten.transpose.int %241, %int-2_6796, %int-1_6797 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_6798 = torch.constant.int 4 - %5593 = torch.aten.mul.int %int4_6798, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6799 = torch.constant.int 4096 - %5594 = torch.prim.ListConstruct %5593, %int4096_6799 : (!torch.int, !torch.int) -> !torch.list - %5595 = torch.aten.view %5591, %5594 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5595, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5596 = torch.aten.mm %5595, %5592 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5596, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_6800 = torch.constant.int 4 - %int14336_6801 = torch.constant.int 14336 - %5597 = torch.prim.ListConstruct %int4_6800, %306, %int14336_6801 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5598 = torch.aten.view %5596, %5597 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5598, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %5599 = torch.aten.silu %5598 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5599, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_6802 = torch.constant.int -2 - %int-1_6803 = torch.constant.int -1 - %5600 = torch.aten.transpose.int %242, %int-2_6802, %int-1_6803 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_6804 = torch.constant.int 4 - %5601 = torch.aten.mul.int %int4_6804, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6805 = torch.constant.int 4096 - %5602 = torch.prim.ListConstruct %5601, %int4096_6805 : (!torch.int, !torch.int) -> !torch.list - %5603 = torch.aten.view %5591, %5602 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - 
torch.bind_symbolic_shape %5603, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5604 = torch.aten.mm %5603, %5600 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5604, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_6806 = torch.constant.int 4 - %int14336_6807 = torch.constant.int 14336 - %5605 = torch.prim.ListConstruct %int4_6806, %306, %int14336_6807 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5606 = torch.aten.view %5604, %5605 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5606, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %5607 = torch.aten.mul.Tensor %5599, %5606 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5607, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_6808 = torch.constant.int -2 - %int-1_6809 = torch.constant.int -1 - %5608 = torch.aten.transpose.int %243, %int-2_6808, %int-1_6809 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_6810 = torch.constant.int 1 - %5609 = torch.aten.size.int %5598, %int1_6810 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_6811 = torch.constant.int 4 - %5610 = torch.aten.mul.int %int4_6811, %5609 : !torch.int, !torch.int -> !torch.int - %int14336_6812 = torch.constant.int 14336 - %5611 = torch.prim.ListConstruct %5610, %int14336_6812 : (!torch.int, !torch.int) -> !torch.list - %5612 = torch.aten.view %5607, %5611 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5612, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %5613 = torch.aten.mm %5612, %5608 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5613, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_6813 = torch.constant.int 4 - %int4096_6814 = torch.constant.int 4096 - %5614 = torch.prim.ListConstruct %int4_6813, %5609, %int4096_6814 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5615 = torch.aten.view %5613, %5614 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5615, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_6815 = torch.constant.int 1 - %5616 = torch.aten.add.Tensor %5581, %5615, %int1_6815 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5616, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_6816 = torch.constant.int 6 - %5617 = torch.prims.convert_element_type %5616, %int6_6816 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5617, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_6817 = torch.constant.int 2 - %5618 = torch.aten.pow.Tensor_Scalar %5617, %int2_6817 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5618, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_6818 = 
torch.constant.int -1 - %5619 = torch.prim.ListConstruct %int-1_6818 : (!torch.int) -> !torch.list - %true_6819 = torch.constant.bool true - %none_6820 = torch.constant.none - %5620 = torch.aten.mean.dim %5618, %5619, %true_6819, %none_6820 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5620, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_6821 = torch.constant.float 9.9999997473787516E-6 - %int1_6822 = torch.constant.int 1 - %5621 = torch.aten.add.Scalar %5620, %float9.999990e-06_6821, %int1_6822 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5621, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5622 = torch.aten.rsqrt %5621 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5622, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5623 = torch.aten.mul.Tensor %5617, %5622 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5623, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6823 = torch.constant.int 5 - %5624 = torch.prims.convert_element_type %5623, %int5_6823 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5624, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %5625 = torch.aten.mul.Tensor %244, %5624 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5625, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_6824 = torch.constant.int 5 - %5626 = torch.prims.convert_element_type %5625, %int5_6824 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5626, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_6825 = torch.constant.int -2 - %int-1_6826 = torch.constant.int -1 - %5627 = torch.aten.transpose.int %245, %int-2_6825, %int-1_6826 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6827 = torch.constant.int 4 - %5628 = torch.aten.mul.int %int4_6827, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6828 = torch.constant.int 4096 - %5629 = torch.prim.ListConstruct %5628, %int4096_6828 : (!torch.int, !torch.int) -> !torch.list - %5630 = torch.aten.view %5626, %5629 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5630, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5631 = torch.aten.mm %5630, %5627 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5631, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_6829 = torch.constant.int 4 - %int4096_6830 = torch.constant.int 4096 - %5632 = torch.prim.ListConstruct %int4_6829, %306, %int4096_6830 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5633 = torch.aten.view %5631, %5632 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5633, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - 
%int-2_6831 = torch.constant.int -2 - %int-1_6832 = torch.constant.int -1 - %5634 = torch.aten.transpose.int %246, %int-2_6831, %int-1_6832 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6833 = torch.constant.int 4 - %5635 = torch.aten.mul.int %int4_6833, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6834 = torch.constant.int 4096 - %5636 = torch.prim.ListConstruct %5635, %int4096_6834 : (!torch.int, !torch.int) -> !torch.list - %5637 = torch.aten.view %5626, %5636 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5637, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5638 = torch.aten.mm %5637, %5634 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %5638, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_6835 = torch.constant.int 4 - %int1024_6836 = torch.constant.int 1024 - %5639 = torch.prim.ListConstruct %int4_6835, %306, %int1024_6836 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5640 = torch.aten.view %5638, %5639 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %5640, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_6837 = torch.constant.int -2 - %int-1_6838 = torch.constant.int -1 - %5641 = torch.aten.transpose.int %247, %int-2_6837, %int-1_6838 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6839 = torch.constant.int 4 - %5642 = torch.aten.mul.int %int4_6839, %306 : !torch.int, !torch.int -> !torch.int - %int4096_6840 = torch.constant.int 4096 - %5643 = torch.prim.ListConstruct %5642, %int4096_6840 : (!torch.int, !torch.int) -> !torch.list - %5644 = torch.aten.view %5626, %5643 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5644, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5645 = torch.aten.mm %5644, %5641 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %5645, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_6841 = torch.constant.int 4 - %int1024_6842 = torch.constant.int 1024 - %5646 = torch.prim.ListConstruct %int4_6841, %306, %int1024_6842 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5647 = torch.aten.view %5645, %5646 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %5647, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_6843 = torch.constant.int 4 - %int32_6844 = torch.constant.int 32 - %int128_6845 = torch.constant.int 128 - %5648 = torch.prim.ListConstruct %int4_6843, %306, %int32_6844, %int128_6845 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5649 = torch.aten.view %5633, %5648 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5649, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_6767 = torch.constant.int 1 + %5579 = torch.prim.ListConstruct %int4_6764, %int1_6765, %int1_6766, %int1_6767 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5580 = torch.aten.repeat %5578, %5579 : 
!torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5580, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %5581 = torch.aten.mul.Tensor %5460, %5574 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5581, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_6768 = torch.constant.int 3 + %int0_6769 = torch.constant.int 0 + %int64_6770 = torch.constant.int 64 + %int1_6771 = torch.constant.int 1 + %5582 = torch.aten.slice.Tensor %5460, %int3_6768, %int0_6769, %int64_6770, %int1_6771 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %5582, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_6772 = torch.constant.int 3 + %int64_6773 = torch.constant.int 64 + %int9223372036854775807_6774 = torch.constant.int 9223372036854775807 + %int1_6775 = torch.constant.int 1 + %5583 = torch.aten.slice.Tensor %5460, %int3_6772, %int64_6773, %int9223372036854775807_6774, %int1_6775 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %5583, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %5584 = torch.aten.neg %5583 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %5584, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %5585 = torch.prim.ListConstruct %5584, %5582 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_6776 = torch.constant.int -1 + %5586 = torch.aten.cat %5585, %int-1_6776 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5586, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %5587 = torch.aten.mul.Tensor %5586, %5580 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5587, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_6777 = torch.constant.int 1 + %5588 = torch.aten.add.Tensor %5581, %5587, %int1_6777 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5588, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_6778 = torch.constant.int 32 + %5589 = torch.aten.mul.Scalar %arg2, %int32_6778 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5589, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int19 = torch.constant.int 19 + %int1_6779 = torch.constant.int 1 + %5590 = torch.aten.add.Scalar %5589, %int19, %int1_6779 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5590, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_6780 = torch.constant.int 2 + %5591 = torch.aten.mul.Scalar %5590, %int2_6780 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5591, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_6781 = torch.constant.int 0 + %int1_6782 = 
torch.constant.int 1 + %5592 = torch.aten.add.Scalar %5591, %int0_6781, %int1_6782 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5592, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %5593 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %5594 = torch.aten.view %5592, %5593 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %5594, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_6783 = torch.constant.int 4 + %int32_6784 = torch.constant.int 32 + %int8_6785 = torch.constant.int 8 + %int128_6786 = torch.constant.int 128 + %5595 = torch.prim.ListConstruct %int4_6783, %296, %int32_6784, %int8_6785, %int128_6786 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5596 = torch.aten.view %5588, %5595 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5596, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_6787 = torch.constant.int 32 + %int8_6788 = torch.constant.int 8 + %int128_6789 = torch.constant.int 128 + %5597 = torch.prim.ListConstruct %504, %int32_6787, %int8_6788, %int128_6789 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5598 = torch.aten.view %5596, %5597 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %5598, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_6790 = torch.constant.int 1 + %int2_6791 = torch.constant.int 2 + %5599 = torch.aten.transpose.int %5598, %int1_6790, %int2_6791 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5599, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_6792 = torch.constant.int 5 + %5600 = torch.prims.convert_element_type %5599, %int5_6792 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5600, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_6793 = torch.constant.int 32 + %int2_6794 = torch.constant.int 2 + %int8_6795 = torch.constant.int 8 + %int32_6796 = torch.constant.int 32 + %int128_6797 = torch.constant.int 128 + %5601 = torch.prim.ListConstruct %297, %int32_6793, %int2_6794, %int8_6795, %int32_6796, %int128_6797 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5602 = torch.aten.view %5364, %5601 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5602, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_6798 = torch.constant.int 8 + %int32_6799 = torch.constant.int 32 + %int128_6800 = torch.constant.int 128 + %5603 = torch.prim.ListConstruct %497, %int8_6798, %int32_6799, %int128_6800 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5604 = torch.aten.view %5602, %5603 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5604, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %5605 = torch.prim.ListConstruct %5594 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_6801 = torch.constant.bool false + %5606 
= torch.aten.index_put %5604, %5605, %5600, %false_6801 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5606, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_6802 = torch.constant.int 32 + %int2_6803 = torch.constant.int 2 + %int8_6804 = torch.constant.int 8 + %int32_6805 = torch.constant.int 32 + %int128_6806 = torch.constant.int 128 + %5607 = torch.prim.ListConstruct %297, %int32_6802, %int2_6803, %int8_6804, %int32_6805, %int128_6806 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5608 = torch.aten.view %5606, %5607 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5608, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_6807 = torch.constant.int 2097152 + %5609 = torch.prim.ListConstruct %297, %int2097152_6807 : (!torch.int, !torch.int) -> !torch.list + %5610 = torch.aten.view %5608, %5609 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5610, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_6808 = torch.constant.int 32 + %int2_6809 = torch.constant.int 2 + %int8_6810 = torch.constant.int 8 + %int32_6811 = torch.constant.int 32 + %int128_6812 = torch.constant.int 128 + %5611 = torch.prim.ListConstruct %297, %int32_6808, %int2_6809, %int8_6810, %int32_6811, %int128_6812 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5612 = torch.aten.view %5610, %5611 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5612, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_6813 = torch.constant.int 8 + %int32_6814 = torch.constant.int 32 + %int128_6815 = torch.constant.int 128 + %5613 = torch.prim.ListConstruct %497, %int8_6813, %int32_6814, %int128_6815 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5614 = torch.aten.view %5612, %5613 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5614, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_6816 = torch.constant.int 32 + %5615 = torch.aten.mul.Scalar %arg2, %int32_6816 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5615, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int19_6817 = torch.constant.int 19 + %int1_6818 = torch.constant.int 1 + %5616 = torch.aten.add.Scalar %5615, %int19_6817, %int1_6818 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5616, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_6819 = torch.constant.int 2 + %5617 = torch.aten.mul.Scalar %5616, %int2_6819 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5617, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_6820 = torch.constant.int 1 + %int1_6821 = torch.constant.int 1 + %5618 = torch.aten.add.Scalar %5617, %int1_6820, %int1_6821 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + 
torch.bind_symbolic_shape %5618, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %5619 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %5620 = torch.aten.view %5618, %5619 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %5620, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_6822 = torch.constant.int 4 + %int32_6823 = torch.constant.int 32 + %int8_6824 = torch.constant.int 8 + %int128_6825 = torch.constant.int 128 + %5621 = torch.prim.ListConstruct %int4_6822, %296, %int32_6823, %int8_6824, %int128_6825 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5622 = torch.aten.view %5462, %5621 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5622, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_6826 = torch.constant.int 32 + %int8_6827 = torch.constant.int 8 + %int128_6828 = torch.constant.int 128 + %5623 = torch.prim.ListConstruct %504, %int32_6826, %int8_6827, %int128_6828 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5624 = torch.aten.view %5622, %5623 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %5624, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_6829 = torch.constant.int 1 + %int2_6830 = torch.constant.int 2 + %5625 = torch.aten.transpose.int %5624, %int1_6829, %int2_6830 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5625, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_6831 = torch.constant.int 5 + %5626 = torch.prims.convert_element_type %5625, %int5_6831 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5626, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %5627 = torch.prim.ListConstruct %5620 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_6832 = torch.constant.bool false + %5628 = torch.aten.index_put %5614, %5627, %5626, %false_6832 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5628, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_6833 = torch.constant.int 32 + %int2_6834 = torch.constant.int 2 + %int8_6835 = torch.constant.int 8 + %int32_6836 = torch.constant.int 32 + %int128_6837 = torch.constant.int 128 + %5629 = torch.prim.ListConstruct %297, %int32_6833, %int2_6834, %int8_6835, %int32_6836, %int128_6837 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5630 = torch.aten.view %5628, %5629 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5630, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_6838 = torch.constant.int 2097152 + %5631 = torch.prim.ListConstruct %297, %int2097152_6838 : (!torch.int, !torch.int) -> !torch.list + %5632 = torch.aten.view %5630, %5631 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5632, [%295], affine_map<()[s0] -> (s0, 2097152)> : 
!torch.vtensor<[?,2097152],f16> + %int-2_6839 = torch.constant.int -2 + %5633 = torch.aten.unsqueeze %5588, %int-2_6839 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5633, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_6840 = torch.constant.int 4 + %int8_6841 = torch.constant.int 8 + %int4_6842 = torch.constant.int 4 + %int128_6843 = torch.constant.int 128 + %5634 = torch.prim.ListConstruct %int4_6840, %298, %int8_6841, %int4_6842, %int128_6843 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_6844 = torch.constant.bool false + %5635 = torch.aten.expand %5633, %5634, %false_6844 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5635, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_6845 = torch.constant.int 0 + %5636 = torch.aten.clone %5635, %int0_6845 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5636, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_6846 = torch.constant.int 4 - %int8_6847 = torch.constant.int 8 + %int32_6847 = torch.constant.int 32 %int128_6848 = torch.constant.int 128 - %5650 = torch.prim.ListConstruct %int4_6846, %306, %int8_6847, %int128_6848 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5651 = torch.aten.view %5640, %5650 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5651, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_6849 = torch.constant.int 4 - %int8_6850 = torch.constant.int 8 - %int128_6851 = torch.constant.int 128 - %5652 = torch.prim.ListConstruct %int4_6849, %306, %int8_6850, %int128_6851 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5653 = torch.aten.view %5647, %5652 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5653, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_6852 = torch.constant.int 131072 - %none_6853 = torch.constant.none - %none_6854 = torch.constant.none - %cpu_6855 = torch.constant.device "cpu" - %false_6856 = torch.constant.bool false - %5654 = torch.aten.arange %int131072_6852, %none_6853, %none_6854, %cpu_6855, %false_6856 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_6857 = torch.constant.int 0 + %5637 = torch.prim.ListConstruct %int4_6846, %298, %int32_6847, %int128_6848 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5638 = torch.aten._unsafe_view %5636, %5637 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5638, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_6849 = torch.constant.int -2 + %5639 = torch.aten.unsqueeze %5462, %int-2_6849 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5639, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_6850 = torch.constant.int 4 + %int8_6851 = torch.constant.int 8 + %int4_6852 = torch.constant.int 4 + %int128_6853 = torch.constant.int 128 + 
%5640 = torch.prim.ListConstruct %int4_6850, %298, %int8_6851, %int4_6852, %int128_6853 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_6854 = torch.constant.bool false + %5641 = torch.aten.expand %5639, %5640, %false_6854 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5641, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_6855 = torch.constant.int 0 + %5642 = torch.aten.clone %5641, %int0_6855 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5642, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_6856 = torch.constant.int 4 + %int32_6857 = torch.constant.int 32 %int128_6858 = torch.constant.int 128 - %none_6859 = torch.constant.none - %none_6860 = torch.constant.none - %cpu_6861 = torch.constant.device "cpu" - %false_6862 = torch.constant.bool false - %5655 = torch.aten.arange.start %int0_6857, %int128_6858, %none_6859, %none_6860, %cpu_6861, %false_6862 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_6863 = torch.constant.int 2 - %5656 = torch.aten.floor_divide.Scalar %5655, %int2_6863 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_6864 = torch.constant.int 6 - %5657 = torch.prims.convert_element_type %5656, %int6_6864 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_6865 = torch.constant.int 128 - %5658 = torch.aten.div.Scalar %5657, %int128_6865 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_6866 = torch.constant.float 2.000000e+00 - %5659 = torch.aten.mul.Scalar %5658, %float2.000000e00_6866 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_6867 = torch.constant.float 5.000000e+05 - %5660 = torch.aten.pow.Scalar %float5.000000e05_6867, %5659 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %5661 = torch.aten.reciprocal %5660 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_6868 = torch.constant.float 1.000000e+00 - %5662 = torch.aten.mul.Scalar %5661, %float1.000000e00_6868 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_6869 = torch.constant.int 1 - %5663 = torch.aten.unsqueeze %5654, %int1_6869 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_6870 = torch.constant.int 0 - %5664 = torch.aten.unsqueeze %5662, %int0_6870 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %5665 = torch.aten.mul.Tensor %5663, %5664 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_6871 = torch.constant.int 1 - %5666 = torch.aten.size.int %5633, %int1_6871 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_6872 = torch.constant.int 0 - %5667 = torch.aten.add.int %int0_6872, %5666 : !torch.int, !torch.int -> !torch.int - %int0_6873 = torch.constant.int 0 - %int0_6874 = torch.constant.int 0 - %int1_6875 = torch.constant.int 1 - %5668 = torch.aten.slice.Tensor %5665, %int0_6873, %int0_6874, %5667, %int1_6875 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5668, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : 
!torch.vtensor<[?,128],f32> - %int1_6876 = torch.constant.int 1 - %int0_6877 = torch.constant.int 0 - %int9223372036854775807_6878 = torch.constant.int 9223372036854775807 - %int1_6879 = torch.constant.int 1 - %5669 = torch.aten.slice.Tensor %5668, %int1_6876, %int0_6877, %int9223372036854775807_6878, %int1_6879 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5669, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_6880 = torch.constant.int 1 - %int0_6881 = torch.constant.int 0 - %int9223372036854775807_6882 = torch.constant.int 9223372036854775807 - %int1_6883 = torch.constant.int 1 - %5670 = torch.aten.slice.Tensor %5669, %int1_6880, %int0_6881, %int9223372036854775807_6882, %int1_6883 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5670, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_6884 = torch.constant.int 0 - %5671 = torch.aten.unsqueeze %5670, %int0_6884 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5671, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> + %5643 = torch.prim.ListConstruct %int4_6856, %298, %int32_6857, %int128_6858 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5644 = torch.aten._unsafe_view %5642, %5643 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5644, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_6859 = torch.constant.int 1 + %int2_6860 = torch.constant.int 2 + %5645 = torch.aten.transpose.int %5525, %int1_6859, %int2_6860 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5645, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_6861 = torch.constant.int 1 + %int2_6862 = torch.constant.int 2 + %5646 = torch.aten.transpose.int %5638, %int1_6861, %int2_6862 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5646, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_6863 = torch.constant.int 1 + %int2_6864 = torch.constant.int 2 + %5647 = torch.aten.transpose.int %5644, %int1_6863, %int2_6864 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5647, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_6865 = torch.constant.float 0.000000e+00 + %false_6866 = torch.constant.bool false + %none_6867 = torch.constant.none + %5648:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5645, %5646, %5647, %float0.000000e00_6865, %false_6866, %327, %none_6867) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %5648#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_6868 = torch.constant.int 1 + %int2_6869 = torch.constant.int 2 + %5649 = torch.aten.transpose.int 
%5648#0, %int1_6868, %int2_6869 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5649, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_6870 = torch.constant.int 4 + %int4096_6871 = torch.constant.int 4096 + %5650 = torch.prim.ListConstruct %int4_6870, %298, %int4096_6871 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5651 = torch.aten.view %5649, %5650 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5651, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_6872 = torch.constant.int -2 + %int-1_6873 = torch.constant.int -1 + %5652 = torch.aten.transpose.int %177, %int-2_6872, %int-1_6873 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6874 = torch.constant.int 5 + %5653 = torch.prims.convert_element_type %5652, %int5_6874 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_6875 = torch.constant.int 4096 + %5654 = torch.prim.ListConstruct %342, %int4096_6875 : (!torch.int, !torch.int) -> !torch.list + %5655 = torch.aten.view %5651, %5654 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5655, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5656 = torch.aten.mm %5655, %5653 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5656, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_6876 = torch.constant.int 4 + %int4096_6877 = torch.constant.int 4096 + %5657 = torch.prim.ListConstruct %int4_6876, %298, %int4096_6877 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5658 = torch.aten.view %5656, %5657 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5658, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_6878 = torch.constant.int 1 + %5659 = torch.aten.add.Tensor %5425, %5658, %int1_6878 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5659, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_6879 = torch.constant.int 6 + %5660 = torch.prims.convert_element_type %5659, %int6_6879 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5660, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_6880 = torch.constant.int 2 + %5661 = torch.aten.pow.Tensor_Scalar %5660, %int2_6880 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5661, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_6881 = torch.constant.int -1 + %5662 = torch.prim.ListConstruct %int-1_6881 : (!torch.int) -> !torch.list + %true_6882 = torch.constant.bool true + %none_6883 = torch.constant.none + %5663 = torch.aten.mean.dim %5661, %5662, %true_6882, %none_6883 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5663, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_6884 = 
torch.constant.float 9.9999997473787516E-6 %int1_6885 = torch.constant.int 1 - %int0_6886 = torch.constant.int 0 - %int9223372036854775807_6887 = torch.constant.int 9223372036854775807 - %int1_6888 = torch.constant.int 1 - %5672 = torch.aten.slice.Tensor %5671, %int1_6885, %int0_6886, %int9223372036854775807_6887, %int1_6888 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5672, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_6889 = torch.constant.int 2 - %int0_6890 = torch.constant.int 0 - %int9223372036854775807_6891 = torch.constant.int 9223372036854775807 - %int1_6892 = torch.constant.int 1 - %5673 = torch.aten.slice.Tensor %5672, %int2_6889, %int0_6890, %int9223372036854775807_6891, %int1_6892 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5673, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_6893 = torch.constant.int 4 - %int1_6894 = torch.constant.int 1 - %int1_6895 = torch.constant.int 1 - %5674 = torch.prim.ListConstruct %int4_6893, %int1_6894, %int1_6895 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5675 = torch.aten.repeat %5673, %5674 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %5675, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_6896 = torch.constant.int 6 - %5676 = torch.prims.convert_element_type %5649, %int6_6896 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %5676, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %5677 = torch_c.to_builtin_tensor %5676 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %5678 = torch_c.to_builtin_tensor %5675 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %5679 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%5677, %5678) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %5680 = torch_c.from_builtin_tensor %5679 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %5680, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_6897 = torch.constant.int 5 - %5681 = torch.prims.convert_element_type %5680, %int5_6897 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5681, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_6898 = torch.constant.int 131072 - %none_6899 = torch.constant.none - %none_6900 = torch.constant.none - %cpu_6901 = torch.constant.device "cpu" - %false_6902 = torch.constant.bool false - %5682 = torch.aten.arange %int131072_6898, %none_6899, %none_6900, %cpu_6901, %false_6902 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_6903 = torch.constant.int 0 - %int128_6904 = torch.constant.int 128 - %none_6905 = torch.constant.none - %none_6906 = torch.constant.none - %cpu_6907 = torch.constant.device "cpu" - %false_6908 = torch.constant.bool false - %5683 = torch.aten.arange.start %int0_6903, %int128_6904, %none_6905, %none_6906, %cpu_6907, %false_6908 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> 
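// Note: the "-" run below removes the old rope frequency table: positions from
// arange(131072), a per-dim exponent of floor_divide(arange(128), 2) / 128, and
// inverse frequencies as reciprocal(500000 ^ exponent), which fed the
// @sharktank_rotary_embedding_* kernels in f32. The "+" side of this hunk
// appears to rebuild the table over the 64 even dims (arange.start_step) and to
// add Llama-3 style scaling: frequencies whose wavelength exceeds 8192 are
// divided by 8, the band between the 2048 and 8192 cutoffs is blended by the
// ramp (8192 / wavelength - 1) / 3, and the resulting cos/sin tables are kept
// in f16 and applied inline via the rotate-half identity.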
- %int2_6909 = torch.constant.int 2 - %5684 = torch.aten.floor_divide.Scalar %5683, %int2_6909 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_6910 = torch.constant.int 6 - %5685 = torch.prims.convert_element_type %5684, %int6_6910 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_6911 = torch.constant.int 128 - %5686 = torch.aten.div.Scalar %5685, %int128_6911 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_6912 = torch.constant.float 2.000000e+00 - %5687 = torch.aten.mul.Scalar %5686, %float2.000000e00_6912 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_6913 = torch.constant.float 5.000000e+05 - %5688 = torch.aten.pow.Scalar %float5.000000e05_6913, %5687 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %5689 = torch.aten.reciprocal %5688 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_6914 = torch.constant.float 1.000000e+00 - %5690 = torch.aten.mul.Scalar %5689, %float1.000000e00_6914 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_6915 = torch.constant.int 1 - %5691 = torch.aten.unsqueeze %5682, %int1_6915 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_6916 = torch.constant.int 0 - %5692 = torch.aten.unsqueeze %5690, %int0_6916 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %5693 = torch.aten.mul.Tensor %5691, %5692 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_6917 = torch.constant.int 1 - %5694 = torch.aten.size.int %5640, %int1_6917 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_6918 = torch.constant.int 0 - %5695 = torch.aten.add.int %int0_6918, %5694 : !torch.int, !torch.int -> !torch.int - %int0_6919 = torch.constant.int 0 - %int0_6920 = torch.constant.int 0 - %int1_6921 = torch.constant.int 1 - %5696 = torch.aten.slice.Tensor %5693, %int0_6919, %int0_6920, %5695, %int1_6921 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5696, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_6922 = torch.constant.int 1 - %int0_6923 = torch.constant.int 0 - %int9223372036854775807_6924 = torch.constant.int 9223372036854775807 - %int1_6925 = torch.constant.int 1 - %5697 = torch.aten.slice.Tensor %5696, %int1_6922, %int0_6923, %int9223372036854775807_6924, %int1_6925 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5697, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_6926 = torch.constant.int 1 - %int0_6927 = torch.constant.int 0 - %int9223372036854775807_6928 = torch.constant.int 9223372036854775807 - %int1_6929 = torch.constant.int 1 - %5698 = torch.aten.slice.Tensor %5697, %int1_6926, %int0_6927, %int9223372036854775807_6928, %int1_6929 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5698, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_6930 = torch.constant.int 0 - %5699 = torch.aten.unsqueeze %5698, %int0_6930 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5699, [%292], affine_map<()[s0] -> (1, 
s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_6931 = torch.constant.int 1 - %int0_6932 = torch.constant.int 0 - %int9223372036854775807_6933 = torch.constant.int 9223372036854775807 - %int1_6934 = torch.constant.int 1 - %5700 = torch.aten.slice.Tensor %5699, %int1_6931, %int0_6932, %int9223372036854775807_6933, %int1_6934 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5700, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_6935 = torch.constant.int 2 - %int0_6936 = torch.constant.int 0 - %int9223372036854775807_6937 = torch.constant.int 9223372036854775807 - %int1_6938 = torch.constant.int 1 - %5701 = torch.aten.slice.Tensor %5700, %int2_6935, %int0_6936, %int9223372036854775807_6937, %int1_6938 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5701, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_6939 = torch.constant.int 4 - %int1_6940 = torch.constant.int 1 - %int1_6941 = torch.constant.int 1 - %5702 = torch.prim.ListConstruct %int4_6939, %int1_6940, %int1_6941 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5703 = torch.aten.repeat %5701, %5702 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %5703, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_6942 = torch.constant.int 6 - %5704 = torch.prims.convert_element_type %5651, %int6_6942 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %5704, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %5705 = torch_c.to_builtin_tensor %5704 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %5706 = torch_c.to_builtin_tensor %5703 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %5707 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%5705, %5706) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %5708 = torch_c.from_builtin_tensor %5707 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %5708, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_6943 = torch.constant.int 5 - %5709 = torch.prims.convert_element_type %5708, %int5_6943 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5709, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_6944 = torch.constant.int 64 - %5710 = torch.aten.mul.Scalar %arg2, %int64_6944 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5710, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int54 = torch.constant.int 54 - %int1_6945 = torch.constant.int 1 - %5711 = torch.aten.add.Scalar %5710, %int54, %int1_6945 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5711, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_6946 = torch.constant.int 4 - %int32_6947 = torch.constant.int 32 - %int8_6948 = torch.constant.int 8 + %5664 = torch.aten.add.Scalar %5663, %float9.999990e-06_6884, %int1_6885 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + 
torch.bind_symbolic_shape %5664, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %5665 = torch.aten.rsqrt %5664 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5665, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %5666 = torch.aten.mul.Tensor %5660, %5665 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5666, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_6886 = torch.constant.int 5 + %5667 = torch.prims.convert_element_type %5666, %int5_6886 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5667, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %5668 = torch.aten.mul.Tensor %178, %5667 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5668, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_6887 = torch.constant.int 5 + %5669 = torch.prims.convert_element_type %5668, %int5_6887 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5669, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_6888 = torch.constant.int -2 + %int-1_6889 = torch.constant.int -1 + %5670 = torch.aten.transpose.int %179, %int-2_6888, %int-1_6889 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6890 = torch.constant.int 5 + %5671 = torch.prims.convert_element_type %5670, %int5_6890 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_6891 = torch.constant.int 4096 + %5672 = torch.prim.ListConstruct %342, %int4096_6891 : (!torch.int, !torch.int) -> !torch.list + %5673 = torch.aten.view %5669, %5672 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5673, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5674 = torch.aten.mm %5673, %5671 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %5674, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_6892 = torch.constant.int 4 + %int14336_6893 = torch.constant.int 14336 + %5675 = torch.prim.ListConstruct %int4_6892, %298, %int14336_6893 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5676 = torch.aten.view %5674, %5675 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %5676, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %5677 = torch.aten.silu %5676 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %5677, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_6894 = torch.constant.int -2 + %int-1_6895 = torch.constant.int -1 + %5678 = torch.aten.transpose.int %180, %int-2_6894, %int-1_6895 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6896 = torch.constant.int 5 + %5679 = torch.prims.convert_element_type %5678, %int5_6896 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_6897 = 
torch.constant.int 4096 + %5680 = torch.prim.ListConstruct %342, %int4096_6897 : (!torch.int, !torch.int) -> !torch.list + %5681 = torch.aten.view %5669, %5680 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5681, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5682 = torch.aten.mm %5681, %5679 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %5682, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_6898 = torch.constant.int 4 + %int14336_6899 = torch.constant.int 14336 + %5683 = torch.prim.ListConstruct %int4_6898, %298, %int14336_6899 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5684 = torch.aten.view %5682, %5683 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %5684, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %5685 = torch.aten.mul.Tensor %5677, %5684 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %5685, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_6900 = torch.constant.int -2 + %int-1_6901 = torch.constant.int -1 + %5686 = torch.aten.transpose.int %181, %int-2_6900, %int-1_6901 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_6902 = torch.constant.int 5 + %5687 = torch.prims.convert_element_type %5686, %int5_6902 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_6903 = torch.constant.int 14336 + %5688 = torch.prim.ListConstruct %342, %int14336_6903 : (!torch.int, !torch.int) -> !torch.list + %5689 = torch.aten.view %5685, %5688 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %5689, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %5690 = torch.aten.mm %5689, %5687 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5690, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_6904 = torch.constant.int 4 + %int4096_6905 = torch.constant.int 4096 + %5691 = torch.prim.ListConstruct %int4_6904, %298, %int4096_6905 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5692 = torch.aten.view %5690, %5691 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5692, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_6906 = torch.constant.int 1 + %5693 = torch.aten.add.Tensor %5659, %5692, %int1_6906 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5693, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_6907 = torch.constant.int 6 + %5694 = torch.prims.convert_element_type %5693, %int6_6907 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5694, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_6908 = torch.constant.int 2 + %5695 = torch.aten.pow.Tensor_Scalar %5694, %int2_6908 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> 
!torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5695, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_6909 = torch.constant.int -1 + %5696 = torch.prim.ListConstruct %int-1_6909 : (!torch.int) -> !torch.list + %true_6910 = torch.constant.bool true + %none_6911 = torch.constant.none + %5697 = torch.aten.mean.dim %5695, %5696, %true_6910, %none_6911 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5697, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_6912 = torch.constant.float 9.9999997473787516E-6 + %int1_6913 = torch.constant.int 1 + %5698 = torch.aten.add.Scalar %5697, %float9.999990e-06_6912, %int1_6913 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5698, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %5699 = torch.aten.rsqrt %5698 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %5699, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %5700 = torch.aten.mul.Tensor %5694, %5699 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5700, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_6914 = torch.constant.int 5 + %5701 = torch.prims.convert_element_type %5700, %int5_6914 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5701, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %5702 = torch.aten.mul.Tensor %182, %5701 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %5702, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_6915 = torch.constant.int 5 + %5703 = torch.prims.convert_element_type %5702, %int5_6915 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5703, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_6916 = torch.constant.int -2 + %int-1_6917 = torch.constant.int -1 + %5704 = torch.aten.transpose.int %183, %int-2_6916, %int-1_6917 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6918 = torch.constant.int 5 + %5705 = torch.prims.convert_element_type %5704, %int5_6918 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_6919 = torch.constant.int 4096 + %5706 = torch.prim.ListConstruct %342, %int4096_6919 : (!torch.int, !torch.int) -> !torch.list + %5707 = torch.aten.view %5703, %5706 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5707, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5708 = torch.aten.mm %5707, %5705 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5708, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_6920 = torch.constant.int 4 + %int4096_6921 = torch.constant.int 4096 + %5709 = torch.prim.ListConstruct %int4_6920, %298, %int4096_6921 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5710 = 
torch.aten.view %5708, %5709 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %5710, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_6922 = torch.constant.int -2 + %int-1_6923 = torch.constant.int -1 + %5711 = torch.aten.transpose.int %184, %int-2_6922, %int-1_6923 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6924 = torch.constant.int 5 + %5712 = torch.prims.convert_element_type %5711, %int5_6924 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_6925 = torch.constant.int 4096 + %5713 = torch.prim.ListConstruct %342, %int4096_6925 : (!torch.int, !torch.int) -> !torch.list + %5714 = torch.aten.view %5703, %5713 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5714, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5715 = torch.aten.mm %5714, %5712 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %5715, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_6926 = torch.constant.int 4 + %int1024_6927 = torch.constant.int 1024 + %5716 = torch.prim.ListConstruct %int4_6926, %298, %int1024_6927 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5717 = torch.aten.view %5715, %5716 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %5717, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_6928 = torch.constant.int -2 + %int-1_6929 = torch.constant.int -1 + %5718 = torch.aten.transpose.int %185, %int-2_6928, %int-1_6929 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6930 = torch.constant.int 5 + %5719 = torch.prims.convert_element_type %5718, %int5_6930 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_6931 = torch.constant.int 4096 + %5720 = torch.prim.ListConstruct %342, %int4096_6931 : (!torch.int, !torch.int) -> !torch.list + %5721 = torch.aten.view %5703, %5720 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %5721, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %5722 = torch.aten.mm %5721, %5719 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %5722, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_6932 = torch.constant.int 4 + %int1024_6933 = torch.constant.int 1024 + %5723 = torch.prim.ListConstruct %int4_6932, %298, %int1024_6933 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5724 = torch.aten.view %5722, %5723 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %5724, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_6934 = torch.constant.int 4 + %int32_6935 = torch.constant.int 32 + %int128_6936 = torch.constant.int 128 + %5725 = torch.prim.ListConstruct %int4_6934, %298, %int32_6935, %int128_6936 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5726 = torch.aten.view %5710, %5725 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape 
%5726, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_6937 = torch.constant.int 4 + %int8_6938 = torch.constant.int 8 + %int128_6939 = torch.constant.int 128 + %5727 = torch.prim.ListConstruct %int4_6937, %298, %int8_6938, %int128_6939 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5728 = torch.aten.view %5717, %5727 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5728, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_6940 = torch.constant.int 4 + %int8_6941 = torch.constant.int 8 + %int128_6942 = torch.constant.int 128 + %5729 = torch.prim.ListConstruct %int4_6940, %298, %int8_6941, %int128_6942 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5730 = torch.aten.view %5724, %5729 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5730, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_6943 = torch.constant.int 131072 + %none_6944 = torch.constant.none + %none_6945 = torch.constant.none + %cpu_6946 = torch.constant.device "cpu" + %false_6947 = torch.constant.bool false + %5731 = torch.aten.arange %int131072_6943, %none_6944, %none_6945, %cpu_6946, %false_6947 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_6948 = torch.constant.int 0 %int128_6949 = torch.constant.int 128 - %5712 = torch.prim.ListConstruct %int4_6946, %398, %int32_6947, %int8_6948, %int128_6949 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5713 = torch.aten.view %5709, %5712 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5713, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_6950 = torch.constant.int 4 - %5714 = torch.aten.mul.int %int4_6950, %398 : !torch.int, !torch.int -> !torch.int - %int32_6951 = torch.constant.int 32 - %int8_6952 = torch.constant.int 8 - %int128_6953 = torch.constant.int 128 - %5715 = torch.prim.ListConstruct %5714, %int32_6951, %int8_6952, %int128_6953 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5716 = torch.aten.view %5713, %5715 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5716, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_6954 = torch.constant.int 4 - %5717 = torch.aten.mul.int %int4_6954, %398 : !torch.int, !torch.int -> !torch.int - %5718 = torch.prim.ListConstruct %5717 : (!torch.int) -> !torch.list - %5719 = torch.aten.view %5711, %5718 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %5719, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_6955 = torch.constant.int 32 - %int2_6956 = torch.constant.int 2 - %int32_6957 = torch.constant.int 32 - %int8_6958 = torch.constant.int 8 - %int128_6959 = torch.constant.int 128 - %5720 = torch.prim.ListConstruct %389, %int32_6955, %int2_6956, %int32_6957, %int8_6958, %int128_6959 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5721 = torch.aten.view %5553, %5720 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5721, [%293], 
affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6960 = torch.constant.int 32 - %5722 = torch.aten.mul.int %389, %int32_6960 : !torch.int, !torch.int -> !torch.int - %int2_6961 = torch.constant.int 2 - %5723 = torch.aten.mul.int %5722, %int2_6961 : !torch.int, !torch.int -> !torch.int - %int32_6962 = torch.constant.int 32 - %int8_6963 = torch.constant.int 8 - %int128_6964 = torch.constant.int 128 - %5724 = torch.prim.ListConstruct %5723, %int32_6962, %int8_6963, %int128_6964 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5725 = torch.aten.view %5721, %5724 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5725, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %5726 = torch.prim.ListConstruct %5719 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_6965 = torch.constant.bool false - %5727 = torch.aten.index_put %5725, %5726, %5716, %false_6965 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5727, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_6966 = torch.constant.int 32 - %int2_6967 = torch.constant.int 2 - %int32_6968 = torch.constant.int 32 - %int8_6969 = torch.constant.int 8 - %int128_6970 = torch.constant.int 128 - %5728 = torch.prim.ListConstruct %389, %int32_6966, %int2_6967, %int32_6968, %int8_6969, %int128_6970 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5729 = torch.aten.view %5727, %5728 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5729, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6971 = torch.constant.int 2097152 - %5730 = torch.prim.ListConstruct %389, %int2097152_6971 : (!torch.int, !torch.int) -> !torch.list - %5731 = torch.aten.view %5729, %5730 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5731, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_6972 = torch.constant.int 32 - %int2_6973 = torch.constant.int 2 - %int32_6974 = torch.constant.int 32 - %int8_6975 = torch.constant.int 8 - %int128_6976 = torch.constant.int 128 - %5732 = torch.prim.ListConstruct %389, %int32_6972, %int2_6973, %int32_6974, %int8_6975, %int128_6976 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5733 = torch.aten.view %5731, %5732 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5733, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6977 = torch.constant.int 32 - %int8_6978 = torch.constant.int 8 - %int128_6979 = torch.constant.int 128 - %5734 = torch.prim.ListConstruct %5723, %int32_6977, %int8_6978, %int128_6979 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5735 = torch.aten.view %5733, %5734 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5735, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_6980 = torch.constant.int 4 - %int32_6981 = torch.constant.int 32 - %int8_6982 = 
torch.constant.int 8 - %int128_6983 = torch.constant.int 128 - %5736 = torch.prim.ListConstruct %int4_6980, %398, %int32_6981, %int8_6982, %int128_6983 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5737 = torch.aten.view %5653, %5736 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5737, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_6984 = torch.constant.int 4 - %5738 = torch.aten.mul.int %int4_6984, %398 : !torch.int, !torch.int -> !torch.int - %int32_6985 = torch.constant.int 32 - %int8_6986 = torch.constant.int 8 - %int128_6987 = torch.constant.int 128 - %5739 = torch.prim.ListConstruct %5738, %int32_6985, %int8_6986, %int128_6987 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5740 = torch.aten.view %5737, %5739 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5740, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_6988 = torch.constant.int 1 + %int2_6950 = torch.constant.int 2 + %int4_6951 = torch.constant.int 4 + %none_6952 = torch.constant.none + %cpu_6953 = torch.constant.device "cpu" + %false_6954 = torch.constant.bool false + %5732 = torch.aten.arange.start_step %int0_6948, %int128_6949, %int2_6950, %int4_6951, %none_6952, %cpu_6953, %false_6954 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_6955 = torch.constant.int 6 + %5733 = torch.prims.convert_element_type %5732, %int6_6955 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_6956 = torch.constant.int 128 + %5734 = torch.aten.div.Scalar %5733, %int128_6956 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_6957 = torch.constant.float 5.000000e+05 + %5735 = torch.aten.pow.Scalar %float5.000000e05_6957, %5734 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5736 = torch.aten.reciprocal %5735 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_6958 = torch.constant.float 1.000000e+00 + %5737 = torch.aten.mul.Scalar %5736, %float1.000000e00_6958 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %5738 = torch.aten.reciprocal %5737 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_6959 = torch.constant.float 6.2831853071795862 + %5739 = torch.aten.mul.Scalar %5738, %float6.283190e00_6959 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_6960 = torch.constant.float 8.192000e+03 + %5740 = torch.aten.gt.Scalar %5739, %float8.192000e03_6960 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_6961 = torch.constant.int 8 + %5741 = torch.aten.div.Scalar %5737, %int8_6961 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5742 = torch.aten.where.self %5740, %5741, %5737 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5743 = torch.aten.reciprocal %5739 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_6962 = torch.constant.int 8192 + %5744 = torch.aten.mul.Scalar %5743, %int8192_6962 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_6963 = torch.constant.int 1 + %int1_6964 = torch.constant.int 1 + %5745 = torch.aten.sub.Scalar %5744, %int1_6963, %int1_6964 
: !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_6965 = torch.constant.int 3 + %5746 = torch.aten.div.Scalar %5745, %int3_6965 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_6966 = torch.constant.int 1 + %int1_6967 = torch.constant.int 1 + %5747 = torch.aten.rsub.Scalar %5746, %int1_6966, %int1_6967 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %5748 = torch.aten.mul.Tensor %5747, %5742 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_6968 = torch.constant.int 8 + %5749 = torch.aten.div.Scalar %5748, %int8_6968 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5750 = torch.aten.mul.Tensor %5746, %5742 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_6969 = torch.constant.int 1 + %5751 = torch.aten.add.Tensor %5749, %5750, %int1_6969 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_6970 = torch.constant.float 2.048000e+03 + %5752 = torch.aten.lt.Scalar %5739, %float2.048000e03_6970 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5753 = torch.aten.bitwise_not %5752 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_6971 = torch.constant.float 8.192000e+03 + %5754 = torch.aten.gt.Scalar %5739, %float8.192000e03_6971 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5755 = torch.aten.bitwise_not %5754 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5756 = torch.aten.mul.Tensor %5753, %5755 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5757 = torch.aten.where.self %5756, %5751, %5742 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5758 = torch.prim.ListConstruct %5757, %5757 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_6972 = torch.constant.int -1 + %5759 = torch.aten.cat %5758, %int-1_6972 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_6973 = torch.constant.int 6 + %5760 = torch.prims.convert_element_type %5759, %int6_6973 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_6974 = torch.constant.int 1 + %5761 = torch.aten.unsqueeze %5731, %int1_6974 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_6975 = torch.constant.int 6 + %5762 = torch.prims.convert_element_type %5761, %int6_6975 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_6976 = torch.constant.int 0 + %5763 = torch.aten.unsqueeze %5760, %int0_6976 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_6977 = torch.constant.int 6 + %5764 = torch.prims.convert_element_type %5763, %int6_6977 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %5765 = torch.aten.mul.Tensor %5762, %5764 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %5766 = torch.aten.cos %5765 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_6978 = torch.constant.int 5 + %5767 = torch.prims.convert_element_type %5766, %int5_6978 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %5768 = torch.aten.sin %5765 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_6979 = torch.constant.int 5 + %5769 = 
torch.prims.convert_element_type %5768, %int5_6979 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_6980 = torch.constant.int 0 + %int0_6981 = torch.constant.int 0 + %int1_6982 = torch.constant.int 1 + %5770 = torch.aten.slice.Tensor %5767, %int0_6980, %int0_6981, %298, %int1_6982 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5770, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_6983 = torch.constant.int 1 + %int0_6984 = torch.constant.int 0 + %int9223372036854775807_6985 = torch.constant.int 9223372036854775807 + %int1_6986 = torch.constant.int 1 + %5771 = torch.aten.slice.Tensor %5770, %int1_6983, %int0_6984, %int9223372036854775807_6985, %int1_6986 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5771, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_6987 = torch.constant.int 0 + %int0_6988 = torch.constant.int 0 %int1_6989 = torch.constant.int 1 - %5741 = torch.aten.add.Scalar %5711, %int1_6988, %int1_6989 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5741, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_6990 = torch.constant.int 4 - %5742 = torch.aten.mul.int %int4_6990, %398 : !torch.int, !torch.int -> !torch.int - %5743 = torch.prim.ListConstruct %5742 : (!torch.int) -> !torch.list - %5744 = torch.aten.view %5741, %5743 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %5744, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %5745 = torch.prim.ListConstruct %5744 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_6991 = torch.constant.bool false - %5746 = torch.aten.index_put %5735, %5745, %5740, %false_6991 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5746, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_6992 = torch.constant.int 32 - %int2_6993 = torch.constant.int 2 - %int32_6994 = torch.constant.int 32 - %int8_6995 = torch.constant.int 8 - %int128_6996 = torch.constant.int 128 - %5747 = torch.prim.ListConstruct %389, %int32_6992, %int2_6993, %int32_6994, %int8_6995, %int128_6996 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5748 = torch.aten.view %5746, %5747 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5748, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6997 = torch.constant.int 2097152 - %5749 = torch.prim.ListConstruct %389, %int2097152_6997 : (!torch.int, !torch.int) -> !torch.list - %5750 = torch.aten.view %5748, %5749 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5750, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_6998 = torch.constant.int -2 - %5751 = torch.aten.unsqueeze %5709, %int-2_6998 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5751, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : 
!torch.vtensor<[4,?,8,1,128],f16> - %int4_6999 = torch.constant.int 4 - %int8_7000 = torch.constant.int 8 - %int4_7001 = torch.constant.int 4 - %int128_7002 = torch.constant.int 128 - %5752 = torch.prim.ListConstruct %int4_6999, %5694, %int8_7000, %int4_7001, %int128_7002 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7003 = torch.constant.bool false - %5753 = torch.aten.expand %5751, %5752, %false_7003 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5753, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_7004 = torch.constant.int 0 - %5754 = torch.aten.clone %5753, %int0_7004 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5754, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7005 = torch.constant.int 4 - %int32_7006 = torch.constant.int 32 - %int128_7007 = torch.constant.int 128 - %5755 = torch.prim.ListConstruct %int4_7005, %5694, %int32_7006, %int128_7007 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5756 = torch.aten._unsafe_view %5754, %5755 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5756, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_7008 = torch.constant.int -2 - %5757 = torch.aten.unsqueeze %5653, %int-2_7008 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5757, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %5772 = torch.aten.slice.Tensor %5769, %int0_6987, %int0_6988, %298, %int1_6989 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5772, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_6990 = torch.constant.int 1 + %int0_6991 = torch.constant.int 0 + %int9223372036854775807_6992 = torch.constant.int 9223372036854775807 + %int1_6993 = torch.constant.int 1 + %5773 = torch.aten.slice.Tensor %5772, %int1_6990, %int0_6991, %int9223372036854775807_6992, %int1_6993 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5773, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_6994 = torch.constant.int 0 + %5774 = torch.aten.unsqueeze %5771, %int0_6994 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5774, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_6995 = torch.constant.int 1 + %int0_6996 = torch.constant.int 0 + %int9223372036854775807_6997 = torch.constant.int 9223372036854775807 + %int1_6998 = torch.constant.int 1 + %5775 = torch.aten.slice.Tensor %5774, %int1_6995, %int0_6996, %int9223372036854775807_6997, %int1_6998 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5775, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_6999 = torch.constant.int 2 + %5776 = torch.aten.unsqueeze %5775, %int2_6999 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + 
torch.bind_symbolic_shape %5776, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_7000 = torch.constant.int 3 + %int0_7001 = torch.constant.int 0 + %int9223372036854775807_7002 = torch.constant.int 9223372036854775807 + %int1_7003 = torch.constant.int 1 + %5777 = torch.aten.slice.Tensor %5776, %int3_7000, %int0_7001, %int9223372036854775807_7002, %int1_7003 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5777, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_7004 = torch.constant.int 4 + %int1_7005 = torch.constant.int 1 + %int1_7006 = torch.constant.int 1 + %int1_7007 = torch.constant.int 1 + %5778 = torch.prim.ListConstruct %int4_7004, %int1_7005, %int1_7006, %int1_7007 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5779 = torch.aten.repeat %5777, %5778 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5779, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_7008 = torch.constant.int 0 + %5780 = torch.aten.unsqueeze %5773, %int0_7008 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5780, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int1_7009 = torch.constant.int 1 - %5758 = torch.aten.size.int %5647, %int1_7009 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_7010 = torch.constant.int 4 - %int8_7011 = torch.constant.int 8 - %int4_7012 = torch.constant.int 4 - %int128_7013 = torch.constant.int 128 - %5759 = torch.prim.ListConstruct %int4_7010, %5758, %int8_7011, %int4_7012, %int128_7013 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7014 = torch.constant.bool false - %5760 = torch.aten.expand %5757, %5759, %false_7014 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5760, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_7010 = torch.constant.int 0 + %int9223372036854775807_7011 = torch.constant.int 9223372036854775807 + %int1_7012 = torch.constant.int 1 + %5781 = torch.aten.slice.Tensor %5780, %int1_7009, %int0_7010, %int9223372036854775807_7011, %int1_7012 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5781, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_7013 = torch.constant.int 2 + %5782 = torch.aten.unsqueeze %5781, %int2_7013 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5782, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_7014 = torch.constant.int 3 %int0_7015 = torch.constant.int 0 - %5761 = torch.aten.clone %5760, %int0_7015 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5761, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7016 = torch.constant.int 4 - %int32_7017 = torch.constant.int 32 - %int128_7018 = torch.constant.int 128 - %5762 = torch.prim.ListConstruct %int4_7016, %5758, %int32_7017, %int128_7018 : (!torch.int, !torch.int, !torch.int, 
!torch.int) -> !torch.list - %5763 = torch.aten._unsafe_view %5761, %5762 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5763, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int9223372036854775807_7016 = torch.constant.int 9223372036854775807 + %int1_7017 = torch.constant.int 1 + %5783 = torch.aten.slice.Tensor %5782, %int3_7014, %int0_7015, %int9223372036854775807_7016, %int1_7017 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5783, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_7018 = torch.constant.int 4 %int1_7019 = torch.constant.int 1 - %int2_7020 = torch.constant.int 2 - %5764 = torch.aten.transpose.int %5681, %int1_7019, %int2_7020 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5764, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_7020 = torch.constant.int 1 %int1_7021 = torch.constant.int 1 - %int2_7022 = torch.constant.int 2 - %5765 = torch.aten.transpose.int %5756, %int1_7021, %int2_7022 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5765, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_7023 = torch.constant.int 1 - %int2_7024 = torch.constant.int 2 - %5766 = torch.aten.transpose.int %5763, %int1_7023, %int2_7024 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5766, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_7025 = torch.constant.float 0.000000e+00 - %true_7026 = torch.constant.bool true - %none_7027 = torch.constant.none - %none_7028 = torch.constant.none - %5767:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5764, %5765, %5766, %float0.000000e00_7025, %true_7026, %none_7027, %none_7028) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %5767#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %5784 = torch.prim.ListConstruct %int4_7018, %int1_7019, %int1_7020, %int1_7021 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5785 = torch.aten.repeat %5783, %5784 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5785, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %5786 = torch.aten.mul.Tensor %5726, %5779 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5786, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_7022 = torch.constant.int 3 + %int0_7023 = torch.constant.int 0 + %int64_7024 = torch.constant.int 64 + %int1_7025 = torch.constant.int 1 + %5787 = torch.aten.slice.Tensor %5726, %int3_7022, %int0_7023, %int64_7024, %int1_7025 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> 
!torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %5787, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_7026 = torch.constant.int 3 + %int64_7027 = torch.constant.int 64 + %int9223372036854775807_7028 = torch.constant.int 9223372036854775807 %int1_7029 = torch.constant.int 1 - %int2_7030 = torch.constant.int 2 - %5768 = torch.aten.transpose.int %5767#0, %int1_7029, %int2_7030 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5768, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_7031 = torch.constant.int 4 - %int4096_7032 = torch.constant.int 4096 - %5769 = torch.prim.ListConstruct %int4_7031, %5666, %int4096_7032 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5770 = torch.aten.view %5768, %5769 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5770, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_7033 = torch.constant.int -2 - %int-1_7034 = torch.constant.int -1 - %5771 = torch.aten.transpose.int %248, %int-2_7033, %int-1_7034 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7035 = torch.constant.int 4 - %5772 = torch.aten.mul.int %int4_7035, %5666 : !torch.int, !torch.int -> !torch.int - %int4096_7036 = torch.constant.int 4096 - %5773 = torch.prim.ListConstruct %5772, %int4096_7036 : (!torch.int, !torch.int) -> !torch.list - %5774 = torch.aten.view %5770, %5773 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5774, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5775 = torch.aten.mm %5774, %5771 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5775, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_7037 = torch.constant.int 4 - %int4096_7038 = torch.constant.int 4096 - %5776 = torch.prim.ListConstruct %int4_7037, %5666, %int4096_7038 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5777 = torch.aten.view %5775, %5776 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5777, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_7039 = torch.constant.int 1 - %5778 = torch.aten.add.Tensor %5616, %5777, %int1_7039 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5778, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_7040 = torch.constant.int 6 - %5779 = torch.prims.convert_element_type %5778, %int6_7040 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5779, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_7041 = torch.constant.int 2 - %5780 = torch.aten.pow.Tensor_Scalar %5779, %int2_7041 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5780, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_7042 = torch.constant.int -1 - %5781 = torch.prim.ListConstruct %int-1_7042 : (!torch.int) -> !torch.list - %true_7043 = torch.constant.bool true 
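// Note: the "-" run continuing below deletes the old post-attention path:
// transpose back to [4,?,32,128], reshape to [4,?,4096], the attn_output
// projection (mm against the transpose of %248), the residual add, and the f32
// RMSNorm (pow 2, mean over dim -1, + ~1e-5 eps, rsqrt) feeding the SwiGLU FFN.
// The same ops reappear renumbered on the "+" side earlier in the hunk; the one
// visible semantic change is the _scaled_dot_product_flash_attention_for_cpu
// call, which previously used is_causal = true with no mask and now takes the
// explicit [4,1,?,?] f16 mask %327 with is_causal = false.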
- %none_7044 = torch.constant.none - %5782 = torch.aten.mean.dim %5780, %5781, %true_7043, %none_7044 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5782, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_7045 = torch.constant.float 9.9999997473787516E-6 - %int1_7046 = torch.constant.int 1 - %5783 = torch.aten.add.Scalar %5782, %float9.999990e-06_7045, %int1_7046 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5783, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5784 = torch.aten.rsqrt %5783 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5784, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5785 = torch.aten.mul.Tensor %5779, %5784 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5785, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_7047 = torch.constant.int 5 - %5786 = torch.prims.convert_element_type %5785, %int5_7047 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5786, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %5787 = torch.aten.mul.Tensor %249, %5786 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5787, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_7048 = torch.constant.int 5 - %5788 = torch.prims.convert_element_type %5787, %int5_7048 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5788, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_7049 = torch.constant.int -2 - %int-1_7050 = torch.constant.int -1 - %5789 = torch.aten.transpose.int %250, %int-2_7049, %int-1_7050 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_7051 = torch.constant.int 4 - %5790 = torch.aten.mul.int %int4_7051, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7052 = torch.constant.int 4096 - %5791 = torch.prim.ListConstruct %5790, %int4096_7052 : (!torch.int, !torch.int) -> !torch.list - %5792 = torch.aten.view %5788, %5791 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5792, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5793 = torch.aten.mm %5792, %5789 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5793, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_7053 = torch.constant.int 4 - %int14336_7054 = torch.constant.int 14336 - %5794 = torch.prim.ListConstruct %int4_7053, %306, %int14336_7054 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5795 = torch.aten.view %5793, %5794 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5795, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %5796 = torch.aten.silu %5795 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5796, 
[%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_7055 = torch.constant.int -2 - %int-1_7056 = torch.constant.int -1 - %5797 = torch.aten.transpose.int %251, %int-2_7055, %int-1_7056 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_7057 = torch.constant.int 4 - %5798 = torch.aten.mul.int %int4_7057, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7058 = torch.constant.int 4096 - %5799 = torch.prim.ListConstruct %5798, %int4096_7058 : (!torch.int, !torch.int) -> !torch.list - %5800 = torch.aten.view %5788, %5799 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5800, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5801 = torch.aten.mm %5800, %5797 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5801, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_7059 = torch.constant.int 4 - %int14336_7060 = torch.constant.int 14336 - %5802 = torch.prim.ListConstruct %int4_7059, %306, %int14336_7060 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5803 = torch.aten.view %5801, %5802 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5803, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %5804 = torch.aten.mul.Tensor %5796, %5803 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %5804, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_7061 = torch.constant.int -2 - %int-1_7062 = torch.constant.int -1 - %5805 = torch.aten.transpose.int %252, %int-2_7061, %int-1_7062 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %5788 = torch.aten.slice.Tensor %5726, %int3_7026, %int64_7027, %int9223372036854775807_7028, %int1_7029 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %5788, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %5789 = torch.aten.neg %5788 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %5789, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %5790 = torch.prim.ListConstruct %5789, %5787 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_7030 = torch.constant.int -1 + %5791 = torch.aten.cat %5790, %int-1_7030 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5791, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %5792 = torch.aten.mul.Tensor %5791, %5785 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5792, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_7031 = torch.constant.int 1 + %5793 = torch.aten.add.Tensor %5786, %5792, %int1_7031 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5793, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : 
!torch.vtensor<[4,?,32,128],f16> + %int131072_7032 = torch.constant.int 131072 + %none_7033 = torch.constant.none + %none_7034 = torch.constant.none + %cpu_7035 = torch.constant.device "cpu" + %false_7036 = torch.constant.bool false + %5794 = torch.aten.arange %int131072_7032, %none_7033, %none_7034, %cpu_7035, %false_7036 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_7037 = torch.constant.int 0 + %int128_7038 = torch.constant.int 128 + %int2_7039 = torch.constant.int 2 + %int4_7040 = torch.constant.int 4 + %none_7041 = torch.constant.none + %cpu_7042 = torch.constant.device "cpu" + %false_7043 = torch.constant.bool false + %5795 = torch.aten.arange.start_step %int0_7037, %int128_7038, %int2_7039, %int4_7040, %none_7041, %cpu_7042, %false_7043 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_7044 = torch.constant.int 6 + %5796 = torch.prims.convert_element_type %5795, %int6_7044 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_7045 = torch.constant.int 128 + %5797 = torch.aten.div.Scalar %5796, %int128_7045 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_7046 = torch.constant.float 5.000000e+05 + %5798 = torch.aten.pow.Scalar %float5.000000e05_7046, %5797 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5799 = torch.aten.reciprocal %5798 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_7047 = torch.constant.float 1.000000e+00 + %5800 = torch.aten.mul.Scalar %5799, %float1.000000e00_7047 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %5801 = torch.aten.reciprocal %5800 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_7048 = torch.constant.float 6.2831853071795862 + %5802 = torch.aten.mul.Scalar %5801, %float6.283190e00_7048 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_7049 = torch.constant.float 8.192000e+03 + %5803 = torch.aten.gt.Scalar %5802, %float8.192000e03_7049 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_7050 = torch.constant.int 8 + %5804 = torch.aten.div.Scalar %5800, %int8_7050 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5805 = torch.aten.where.self %5803, %5804, %5800 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5806 = torch.aten.reciprocal %5802 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_7051 = torch.constant.int 8192 + %5807 = torch.aten.mul.Scalar %5806, %int8192_7051 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_7052 = torch.constant.int 1 + %int1_7053 = torch.constant.int 1 + %5808 = torch.aten.sub.Scalar %5807, %int1_7052, %int1_7053 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_7054 = torch.constant.int 3 + %5809 = torch.aten.div.Scalar %5808, %int3_7054 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_7055 = torch.constant.int 1 + %int1_7056 = torch.constant.int 1 + %5810 = torch.aten.rsub.Scalar %5809, %int1_7055, %int1_7056 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %5811 = torch.aten.mul.Tensor %5810, %5805 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_7057 = torch.constant.int 8 + %5812 = 
torch.aten.div.Scalar %5811, %int8_7057 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %5813 = torch.aten.mul.Tensor %5809, %5805 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_7058 = torch.constant.int 1 + %5814 = torch.aten.add.Tensor %5812, %5813, %int1_7058 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_7059 = torch.constant.float 2.048000e+03 + %5815 = torch.aten.lt.Scalar %5802, %float2.048000e03_7059 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5816 = torch.aten.bitwise_not %5815 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_7060 = torch.constant.float 8.192000e+03 + %5817 = torch.aten.gt.Scalar %5802, %float8.192000e03_7060 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %5818 = torch.aten.bitwise_not %5817 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5819 = torch.aten.mul.Tensor %5816, %5818 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %5820 = torch.aten.where.self %5819, %5814, %5805 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %5821 = torch.prim.ListConstruct %5820, %5820 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_7061 = torch.constant.int -1 + %5822 = torch.aten.cat %5821, %int-1_7061 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_7062 = torch.constant.int 6 + %5823 = torch.prims.convert_element_type %5822, %int6_7062 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> %int1_7063 = torch.constant.int 1 - %5806 = torch.aten.size.int %5795, %int1_7063 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_7064 = torch.constant.int 4 - %5807 = torch.aten.mul.int %int4_7064, %5806 : !torch.int, !torch.int -> !torch.int - %int14336_7065 = torch.constant.int 14336 - %5808 = torch.prim.ListConstruct %5807, %int14336_7065 : (!torch.int, !torch.int) -> !torch.list - %5809 = torch.aten.view %5804, %5808 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %5809, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %5810 = torch.aten.mm %5809, %5805 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5810, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_7066 = torch.constant.int 4 - %int4096_7067 = torch.constant.int 4096 - %5811 = torch.prim.ListConstruct %int4_7066, %5806, %int4096_7067 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5812 = torch.aten.view %5810, %5811 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5812, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_7068 = torch.constant.int 1 - %5813 = torch.aten.add.Tensor %5778, %5812, %int1_7068 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5813, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_7069 = torch.constant.int 6 - %5814 = torch.prims.convert_element_type %5813, %int6_7069 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5814, [%292], 
affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_7070 = torch.constant.int 2 - %5815 = torch.aten.pow.Tensor_Scalar %5814, %int2_7070 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5815, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_7071 = torch.constant.int -1 - %5816 = torch.prim.ListConstruct %int-1_7071 : (!torch.int) -> !torch.list - %true_7072 = torch.constant.bool true - %none_7073 = torch.constant.none - %5817 = torch.aten.mean.dim %5815, %5816, %true_7072, %none_7073 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5817, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_7074 = torch.constant.float 9.9999997473787516E-6 + %5824 = torch.aten.unsqueeze %5794, %int1_7063 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_7064 = torch.constant.int 6 + %5825 = torch.prims.convert_element_type %5824, %int6_7064 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_7065 = torch.constant.int 0 + %5826 = torch.aten.unsqueeze %5823, %int0_7065 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_7066 = torch.constant.int 6 + %5827 = torch.prims.convert_element_type %5826, %int6_7066 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %5828 = torch.aten.mul.Tensor %5825, %5827 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %5829 = torch.aten.cos %5828 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_7067 = torch.constant.int 5 + %5830 = torch.prims.convert_element_type %5829, %int5_7067 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %5831 = torch.aten.sin %5828 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_7068 = torch.constant.int 5 + %5832 = torch.prims.convert_element_type %5831, %int5_7068 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_7069 = torch.constant.int 0 + %int0_7070 = torch.constant.int 0 + %int1_7071 = torch.constant.int 1 + %5833 = torch.aten.slice.Tensor %5830, %int0_7069, %int0_7070, %298, %int1_7071 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5833, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_7072 = torch.constant.int 1 + %int0_7073 = torch.constant.int 0 + %int9223372036854775807_7074 = torch.constant.int 9223372036854775807 %int1_7075 = torch.constant.int 1 - %5818 = torch.aten.add.Scalar %5817, %float9.999990e-06_7074, %int1_7075 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5818, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5819 = torch.aten.rsqrt %5818 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %5819, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %5820 = torch.aten.mul.Tensor %5814, %5819 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5820, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : 
!torch.vtensor<[4,?,4096],f32> - %int5_7076 = torch.constant.int 5 - %5821 = torch.prims.convert_element_type %5820, %int5_7076 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5821, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %5822 = torch.aten.mul.Tensor %253, %5821 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %5822, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_7077 = torch.constant.int 5 - %5823 = torch.prims.convert_element_type %5822, %int5_7077 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5823, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_7078 = torch.constant.int -2 - %int-1_7079 = torch.constant.int -1 - %5824 = torch.aten.transpose.int %254, %int-2_7078, %int-1_7079 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7080 = torch.constant.int 4 - %5825 = torch.aten.mul.int %int4_7080, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7081 = torch.constant.int 4096 - %5826 = torch.prim.ListConstruct %5825, %int4096_7081 : (!torch.int, !torch.int) -> !torch.list - %5827 = torch.aten.view %5823, %5826 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5827, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5828 = torch.aten.mm %5827, %5824 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5828, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_7082 = torch.constant.int 4 - %int4096_7083 = torch.constant.int 4096 - %5829 = torch.prim.ListConstruct %int4_7082, %306, %int4096_7083 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5830 = torch.aten.view %5828, %5829 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %5830, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_7084 = torch.constant.int -2 - %int-1_7085 = torch.constant.int -1 - %5831 = torch.aten.transpose.int %255, %int-2_7084, %int-1_7085 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_7086 = torch.constant.int 4 - %5832 = torch.aten.mul.int %int4_7086, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7087 = torch.constant.int 4096 - %5833 = torch.prim.ListConstruct %5832, %int4096_7087 : (!torch.int, !torch.int) -> !torch.list - %5834 = torch.aten.view %5823, %5833 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5834, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5835 = torch.aten.mm %5834, %5831 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %5835, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_7088 = torch.constant.int 4 - %int1024_7089 = torch.constant.int 1024 - %5836 = torch.prim.ListConstruct %int4_7088, %306, %int1024_7089 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5837 = torch.aten.view %5835, %5836 : !torch.vtensor<[?,1024],f16>, !torch.list -> 
!torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %5837, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_7090 = torch.constant.int -2 - %int-1_7091 = torch.constant.int -1 - %5838 = torch.aten.transpose.int %256, %int-2_7090, %int-1_7091 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_7092 = torch.constant.int 4 - %5839 = torch.aten.mul.int %int4_7092, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7093 = torch.constant.int 4096 - %5840 = torch.prim.ListConstruct %5839, %int4096_7093 : (!torch.int, !torch.int) -> !torch.list - %5841 = torch.aten.view %5823, %5840 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %5841, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %5842 = torch.aten.mm %5841, %5838 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %5842, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_7094 = torch.constant.int 4 - %int1024_7095 = torch.constant.int 1024 - %5843 = torch.prim.ListConstruct %int4_7094, %306, %int1024_7095 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5844 = torch.aten.view %5842, %5843 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %5844, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_7096 = torch.constant.int 4 - %int32_7097 = torch.constant.int 32 - %int128_7098 = torch.constant.int 128 - %5845 = torch.prim.ListConstruct %int4_7096, %306, %int32_7097, %int128_7098 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5846 = torch.aten.view %5830, %5845 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5846, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_7099 = torch.constant.int 4 - %int8_7100 = torch.constant.int 8 - %int128_7101 = torch.constant.int 128 - %5847 = torch.prim.ListConstruct %int4_7099, %306, %int8_7100, %int128_7101 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5848 = torch.aten.view %5837, %5847 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5848, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_7102 = torch.constant.int 4 - %int8_7103 = torch.constant.int 8 - %int128_7104 = torch.constant.int 128 - %5849 = torch.prim.ListConstruct %int4_7102, %306, %int8_7103, %int128_7104 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5850 = torch.aten.view %5844, %5849 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5850, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_7105 = torch.constant.int 131072 - %none_7106 = torch.constant.none - %none_7107 = torch.constant.none - %cpu_7108 = torch.constant.device "cpu" - %false_7109 = torch.constant.bool false - %5851 = torch.aten.arange %int131072_7105, %none_7106, %none_7107, %cpu_7108, %false_7109 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_7110 = torch.constant.int 0 - %int128_7111 = torch.constant.int 128 - %none_7112 = torch.constant.none - 
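
The `+` ops above replace the exported f32 rotary kernel (the removed `util.call @sharktank_rotary_embedding_*` sites below) with inline f16 RoPE using Llama 3.1-style frequency scaling. Every constant is literal in the IR: theta 5.0e5, head_dim 128, scale factor 8, original context 8192, wavelength cutoffs 8192/2048, smoothing divisor 3. A minimal NumPy sketch of the arithmetic those ops trace out; names like `factor` and `old_context_len` are descriptive choices, not identifiers from this file, and `%5786` (produced just before this excerpt) is presumably q·cos.

```python
import numpy as np

# Constants read off the torch.constant ops in this hunk (assumed meanings).
base = 5.0e5                # %float5.000000e05: RoPE theta
head_dim = 128
factor = 8                  # %int8: long-wavelength scale-down
old_context_len = 8192.0    # %int8192 / %float8.192000e+03
low_freq_wavelen = 8192.0   # gt.Scalar threshold (%5803)
high_freq_wavelen = 2048.0  # lt.Scalar threshold (%5815)

def scaled_inv_freq():
    # arange.start_step(0, 128, 2)/128 -> exponents; pow + reciprocal -> base^(-2k/128)
    exp = np.arange(0, head_dim, 2) / head_dim
    inv_freq = 1.0 / base ** exp                          # %5795..%5800, [64]
    wavelen = 2.0 * np.pi / inv_freq                      # %5801/%5802
    smooth = (old_context_len / wavelen - 1.0) / 3.0      # %5806..%5809; 3 = high(4) - low(1)
    scaled = np.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)  # %5805
    mid = (1.0 - smooth) * scaled / factor + smooth * scaled                    # %5812 + %5813
    in_band = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)    # %5819
    return np.where(in_band, mid, scaled)                 # %5820

def rotate_half(x):
    # slice(:, 64:), neg, cat with slice(:, :64) -- the %5788/%5789/%5791 pattern
    x1, x2 = x[..., :head_dim // 2], x[..., head_dim // 2:]
    return np.concatenate([-x2, x1], axis=-1)

def apply_rope(q, cos, sin):
    # %5793 = %5786 + %5792, i.e. q*cos + rotate_half(q)*sin
    return q * cos + rotate_half(q) * sin
```

The 64 scaled frequencies are then widened to 128 lanes by self-concatenation (%5821/%5822) before the position-indexed cos/sin tables (%5829/%5831) are built and cast to f16.
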
%none_7113 = torch.constant.none - %cpu_7114 = torch.constant.device "cpu" - %false_7115 = torch.constant.bool false - %5852 = torch.aten.arange.start %int0_7110, %int128_7111, %none_7112, %none_7113, %cpu_7114, %false_7115 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_7116 = torch.constant.int 2 - %5853 = torch.aten.floor_divide.Scalar %5852, %int2_7116 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_7117 = torch.constant.int 6 - %5854 = torch.prims.convert_element_type %5853, %int6_7117 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_7118 = torch.constant.int 128 - %5855 = torch.aten.div.Scalar %5854, %int128_7118 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_7119 = torch.constant.float 2.000000e+00 - %5856 = torch.aten.mul.Scalar %5855, %float2.000000e00_7119 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_7120 = torch.constant.float 5.000000e+05 - %5857 = torch.aten.pow.Scalar %float5.000000e05_7120, %5856 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %5858 = torch.aten.reciprocal %5857 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_7121 = torch.constant.float 1.000000e+00 - %5859 = torch.aten.mul.Scalar %5858, %float1.000000e00_7121 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %5834 = torch.aten.slice.Tensor %5833, %int1_7072, %int0_7073, %int9223372036854775807_7074, %int1_7075 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5834, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_7076 = torch.constant.int 0 + %int0_7077 = torch.constant.int 0 + %int1_7078 = torch.constant.int 1 + %5835 = torch.aten.slice.Tensor %5832, %int0_7076, %int0_7077, %298, %int1_7078 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5835, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_7079 = torch.constant.int 1 + %int0_7080 = torch.constant.int 0 + %int9223372036854775807_7081 = torch.constant.int 9223372036854775807 + %int1_7082 = torch.constant.int 1 + %5836 = torch.aten.slice.Tensor %5835, %int1_7079, %int0_7080, %int9223372036854775807_7081, %int1_7082 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5836, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_7083 = torch.constant.int 0 + %5837 = torch.aten.unsqueeze %5834, %int0_7083 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5837, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_7084 = torch.constant.int 1 + %int0_7085 = torch.constant.int 0 + %int9223372036854775807_7086 = torch.constant.int 9223372036854775807 + %int1_7087 = torch.constant.int 1 + %5838 = torch.aten.slice.Tensor %5837, %int1_7084, %int0_7085, %int9223372036854775807_7086, %int1_7087 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5838, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_7088 = 
torch.constant.int 2 + %5839 = torch.aten.unsqueeze %5838, %int2_7088 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5839, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_7089 = torch.constant.int 3 + %int0_7090 = torch.constant.int 0 + %int9223372036854775807_7091 = torch.constant.int 9223372036854775807 + %int1_7092 = torch.constant.int 1 + %5840 = torch.aten.slice.Tensor %5839, %int3_7089, %int0_7090, %int9223372036854775807_7091, %int1_7092 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5840, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_7093 = torch.constant.int 4 + %int1_7094 = torch.constant.int 1 + %int1_7095 = torch.constant.int 1 + %int1_7096 = torch.constant.int 1 + %5841 = torch.prim.ListConstruct %int4_7093, %int1_7094, %int1_7095, %int1_7096 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5842 = torch.aten.repeat %5840, %5841 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5842, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_7097 = torch.constant.int 0 + %5843 = torch.aten.unsqueeze %5836, %int0_7097 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5843, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_7098 = torch.constant.int 1 + %int0_7099 = torch.constant.int 0 + %int9223372036854775807_7100 = torch.constant.int 9223372036854775807 + %int1_7101 = torch.constant.int 1 + %5844 = torch.aten.slice.Tensor %5843, %int1_7098, %int0_7099, %int9223372036854775807_7100, %int1_7101 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %5844, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_7102 = torch.constant.int 2 + %5845 = torch.aten.unsqueeze %5844, %int2_7102 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5845, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_7103 = torch.constant.int 3 + %int0_7104 = torch.constant.int 0 + %int9223372036854775807_7105 = torch.constant.int 9223372036854775807 + %int1_7106 = torch.constant.int 1 + %5846 = torch.aten.slice.Tensor %5845, %int3_7103, %int0_7104, %int9223372036854775807_7105, %int1_7106 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %5846, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_7107 = torch.constant.int 4 + %int1_7108 = torch.constant.int 1 + %int1_7109 = torch.constant.int 1 + %int1_7110 = torch.constant.int 1 + %5847 = torch.prim.ListConstruct %int4_7107, %int1_7108, %int1_7109, %int1_7110 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5848 = torch.aten.repeat %5846, %5847 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %5848, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %5849 = torch.aten.mul.Tensor %5728, %5842 : !torch.vtensor<[4,?,8,128],f16>, 
!torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5849, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_7111 = torch.constant.int 3 + %int0_7112 = torch.constant.int 0 + %int64_7113 = torch.constant.int 64 + %int1_7114 = torch.constant.int 1 + %5850 = torch.aten.slice.Tensor %5728, %int3_7111, %int0_7112, %int64_7113, %int1_7114 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %5850, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_7115 = torch.constant.int 3 + %int64_7116 = torch.constant.int 64 + %int9223372036854775807_7117 = torch.constant.int 9223372036854775807 + %int1_7118 = torch.constant.int 1 + %5851 = torch.aten.slice.Tensor %5728, %int3_7115, %int64_7116, %int9223372036854775807_7117, %int1_7118 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %5851, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %5852 = torch.aten.neg %5851 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %5852, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %5853 = torch.prim.ListConstruct %5852, %5850 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_7119 = torch.constant.int -1 + %5854 = torch.aten.cat %5853, %int-1_7119 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5854, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %5855 = torch.aten.mul.Tensor %5854, %5848 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5855, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_7120 = torch.constant.int 1 + %5856 = torch.aten.add.Tensor %5849, %5855, %int1_7120 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5856, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_7121 = torch.constant.int 32 + %5857 = torch.aten.mul.Scalar %arg2, %int32_7121 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5857, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int20 = torch.constant.int 20 %int1_7122 = torch.constant.int 1 - %5860 = torch.aten.unsqueeze %5851, %int1_7122 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_7123 = torch.constant.int 0 - %5861 = torch.aten.unsqueeze %5859, %int0_7123 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %5862 = torch.aten.mul.Tensor %5860, %5861 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_7124 = torch.constant.int 1 - %5863 = torch.aten.size.int %5830, %int1_7124 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int - %int0_7125 = torch.constant.int 0 - %5864 = torch.aten.add.int %int0_7125, %5863 : !torch.int, !torch.int -> !torch.int - %int0_7126 = torch.constant.int 0 - %int0_7127 = torch.constant.int 0 - %int1_7128 = torch.constant.int 1 - %5865 = torch.aten.slice.Tensor %5862, 
%int0_7126, %int0_7127, %5864, %int1_7128 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5865, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_7129 = torch.constant.int 1 - %int0_7130 = torch.constant.int 0 - %int9223372036854775807_7131 = torch.constant.int 9223372036854775807 - %int1_7132 = torch.constant.int 1 - %5866 = torch.aten.slice.Tensor %5865, %int1_7129, %int0_7130, %int9223372036854775807_7131, %int1_7132 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5866, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %5858 = torch.aten.add.Scalar %5857, %int20, %int1_7122 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5858, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_7123 = torch.constant.int 2 + %5859 = torch.aten.mul.Scalar %5858, %int2_7123 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5859, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_7124 = torch.constant.int 0 + %int1_7125 = torch.constant.int 1 + %5860 = torch.aten.add.Scalar %5859, %int0_7124, %int1_7125 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5860, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %5861 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %5862 = torch.aten.view %5860, %5861 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %5862, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_7126 = torch.constant.int 4 + %int32_7127 = torch.constant.int 32 + %int8_7128 = torch.constant.int 8 + %int128_7129 = torch.constant.int 128 + %5863 = torch.prim.ListConstruct %int4_7126, %296, %int32_7127, %int8_7128, %int128_7129 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5864 = torch.aten.view %5856, %5863 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5864, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_7130 = torch.constant.int 32 + %int8_7131 = torch.constant.int 8 + %int128_7132 = torch.constant.int 128 + %5865 = torch.prim.ListConstruct %504, %int32_7130, %int8_7131, %int128_7132 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5866 = torch.aten.view %5864, %5865 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %5866, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> %int1_7133 = torch.constant.int 1 - %int0_7134 = torch.constant.int 0 - %int9223372036854775807_7135 = torch.constant.int 9223372036854775807 - %int1_7136 = torch.constant.int 1 - %5867 = torch.aten.slice.Tensor %5866, %int1_7133, %int0_7134, %int9223372036854775807_7135, %int1_7136 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5867, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_7137 = torch.constant.int 0 - %5868 = torch.aten.unsqueeze %5867, %int0_7137 : 
!torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5868, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_7138 = torch.constant.int 1 - %int0_7139 = torch.constant.int 0 - %int9223372036854775807_7140 = torch.constant.int 9223372036854775807 - %int1_7141 = torch.constant.int 1 - %5869 = torch.aten.slice.Tensor %5868, %int1_7138, %int0_7139, %int9223372036854775807_7140, %int1_7141 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5869, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_7142 = torch.constant.int 2 - %int0_7143 = torch.constant.int 0 - %int9223372036854775807_7144 = torch.constant.int 9223372036854775807 - %int1_7145 = torch.constant.int 1 - %5870 = torch.aten.slice.Tensor %5869, %int2_7142, %int0_7143, %int9223372036854775807_7144, %int1_7145 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5870, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_7146 = torch.constant.int 4 - %int1_7147 = torch.constant.int 1 - %int1_7148 = torch.constant.int 1 - %5871 = torch.prim.ListConstruct %int4_7146, %int1_7147, %int1_7148 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5872 = torch.aten.repeat %5870, %5871 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %5872, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_7149 = torch.constant.int 6 - %5873 = torch.prims.convert_element_type %5846, %int6_7149 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %5873, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %5874 = torch_c.to_builtin_tensor %5873 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %5875 = torch_c.to_builtin_tensor %5872 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %5876 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%5874, %5875) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %5877 = torch_c.from_builtin_tensor %5876 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %5877, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_7150 = torch.constant.int 5 - %5878 = torch.prims.convert_element_type %5877, %int5_7150 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5878, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_7151 = torch.constant.int 131072 - %none_7152 = torch.constant.none - %none_7153 = torch.constant.none - %cpu_7154 = torch.constant.device "cpu" - %false_7155 = torch.constant.bool false - %5879 = torch.aten.arange %int131072_7151, %none_7152, %none_7153, %cpu_7154, %false_7155 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_7156 = torch.constant.int 0 - %int128_7157 = torch.constant.int 128 - %none_7158 = torch.constant.none - %none_7159 = torch.constant.none - %cpu_7160 = torch.constant.device "cpu" - %false_7161 = torch.constant.bool false - %5880 = torch.aten.arange.start %int0_7156, %int128_7157, 
%none_7158, %none_7159, %cpu_7160, %false_7161 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> + %int2_7134 = torch.constant.int 2 + %5867 = torch.aten.transpose.int %5866, %int1_7133, %int2_7134 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5867, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_7135 = torch.constant.int 5 + %5868 = torch.prims.convert_element_type %5867, %int5_7135 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5868, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_7136 = torch.constant.int 32 + %int2_7137 = torch.constant.int 2 + %int8_7138 = torch.constant.int 8 + %int32_7139 = torch.constant.int 32 + %int128_7140 = torch.constant.int 128 + %5869 = torch.prim.ListConstruct %297, %int32_7136, %int2_7137, %int8_7138, %int32_7139, %int128_7140 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5870 = torch.aten.view %5632, %5869 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5870, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_7141 = torch.constant.int 8 + %int32_7142 = torch.constant.int 32 + %int128_7143 = torch.constant.int 128 + %5871 = torch.prim.ListConstruct %497, %int8_7141, %int32_7142, %int128_7143 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5872 = torch.aten.view %5870, %5871 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5872, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %5873 = torch.prim.ListConstruct %5862 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_7144 = torch.constant.bool false + %5874 = torch.aten.index_put %5872, %5873, %5868, %false_7144 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5874, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_7145 = torch.constant.int 32 + %int2_7146 = torch.constant.int 2 + %int8_7147 = torch.constant.int 8 + %int32_7148 = torch.constant.int 32 + %int128_7149 = torch.constant.int 128 + %5875 = torch.prim.ListConstruct %297, %int32_7145, %int2_7146, %int8_7147, %int32_7148, %int128_7149 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5876 = torch.aten.view %5874, %5875 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5876, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7150 = torch.constant.int 2097152 + %5877 = torch.prim.ListConstruct %297, %int2097152_7150 : (!torch.int, !torch.int) -> !torch.list + %5878 = torch.aten.view %5876, %5877 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5878, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_7151 = torch.constant.int 32 + %int2_7152 = torch.constant.int 2 + %int8_7153 = torch.constant.int 8 + %int32_7154 = torch.constant.int 32 + 
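
The `+` ops above write this block's rotated keys into the shared paged cache (ultimately %arg3, viewed through %5632): the flat [?,2097152] buffer is reshaped to [pages, 32, 2, 8, 32, 128] (page stride 2097152 = 32·2·8·32·128), flattened to [pages·64, 8, 32, 128], and written with index_put at slot (page_id·32 + 20)·2 + 0, where %arg2 holds page ids, 20 is evidently this block's index, and the trailing +0/+1 selects the key or value half. A hedged PyTorch sketch of that bookkeeping; tensor and argument names are mine, and the IR's functional `aten.index_put` is rendered in-place here for brevity.

```python
import torch

BLOCKS, KV, HEADS, PAGE_TOKENS, DIM = 32, 2, 8, 32, 128
PAGE_STRIDE = BLOCKS * KV * HEADS * PAGE_TOKENS * DIM  # 2097152, as in the IR

def write_k_pages(cache_flat, page_ids, k_pages, block_idx=20, kv=0):
    """cache_flat: [n_pages, 2097152]; page_ids: [bs, pages_per_seq];
    k_pages: [bs*pages_per_seq, HEADS, PAGE_TOKENS, DIM], heads-major as after %5867."""
    n_pages = cache_flat.shape[0]
    # view [?,2097152] -> [?,32,2,8,32,128] -> [?*64,8,32,128]   (%5870 / %5872)
    slots = cache_flat.view(n_pages, BLOCKS, KV, HEADS, PAGE_TOKENS, DIM) \
                      .view(n_pages * BLOCKS * KV, HEADS, PAGE_TOKENS, DIM)
    # (page*32 + block)*2 + kv, flattened                         (%5857..%5860, %5862)
    idx = ((page_ids * BLOCKS + block_idx) * KV + kv).view(-1)
    slots.index_put_((idx,), k_pages)                             # %5874
    return slots.view(n_pages, -1)                                # back to [?,2097152] (%5878)
```
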
%int128_7155 = torch.constant.int 128 + %5879 = torch.prim.ListConstruct %297, %int32_7151, %int2_7152, %int8_7153, %int32_7154, %int128_7155 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5880 = torch.aten.view %5878, %5879 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5880, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_7156 = torch.constant.int 8 + %int32_7157 = torch.constant.int 32 + %int128_7158 = torch.constant.int 128 + %5881 = torch.prim.ListConstruct %497, %int8_7156, %int32_7157, %int128_7158 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5882 = torch.aten.view %5880, %5881 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5882, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_7159 = torch.constant.int 32 + %5883 = torch.aten.mul.Scalar %arg2, %int32_7159 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5883, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int20_7160 = torch.constant.int 20 + %int1_7161 = torch.constant.int 1 + %5884 = torch.aten.add.Scalar %5883, %int20_7160, %int1_7161 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5884, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> %int2_7162 = torch.constant.int 2 - %5881 = torch.aten.floor_divide.Scalar %5880, %int2_7162 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_7163 = torch.constant.int 6 - %5882 = torch.prims.convert_element_type %5881, %int6_7163 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_7164 = torch.constant.int 128 - %5883 = torch.aten.div.Scalar %5882, %int128_7164 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_7165 = torch.constant.float 2.000000e+00 - %5884 = torch.aten.mul.Scalar %5883, %float2.000000e00_7165 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_7166 = torch.constant.float 5.000000e+05 - %5885 = torch.aten.pow.Scalar %float5.000000e05_7166, %5884 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %5886 = torch.aten.reciprocal %5885 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_7167 = torch.constant.float 1.000000e+00 - %5887 = torch.aten.mul.Scalar %5886, %float1.000000e00_7167 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_7168 = torch.constant.int 1 - %5888 = torch.aten.unsqueeze %5879, %int1_7168 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_7169 = torch.constant.int 0 - %5889 = torch.aten.unsqueeze %5887, %int0_7169 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %5890 = torch.aten.mul.Tensor %5888, %5889 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_7170 = torch.constant.int 1 - %5891 = torch.aten.size.int %5837, %int1_7170 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_7171 = torch.constant.int 0 - %5892 = torch.aten.add.int %int0_7171, %5891 : !torch.int, !torch.int -> !torch.int - %int0_7172 = torch.constant.int 0 - %int0_7173 = torch.constant.int 0 - 
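
The removed lines here also show how the frequency table itself changed layout: the old code built a 128-wide interleaved table via floor_divide(arange(128), 2) (each frequency repeated in adjacent lanes, as the f32 @sharktank_rotary_embedding kernel expected), while the new code computes 64 frequencies from arange(0, 128, 2) and duplicates them by concatenation (%5821/%5822), the layout rotate_half expects. A small check, assuming the unscaled base case, that the two constructions yield the same multiset of frequencies in a different order:

```python
import numpy as np

base, head_dim = 5.0e5, 128

# old ('-') layout: floor_divide(arange(128), 2) / 128 * 2 -> interleaved pairs
old = 1.0 / base ** (2 * (np.arange(head_dim) // 2) / head_dim)

# new ('+') layout: arange(0, 128, 2) / 128, then cat([f, f], dim=-1) -> two halves
f = 1.0 / base ** (np.arange(0, head_dim, 2) / head_dim)
new = np.concatenate([f, f])

assert np.allclose(np.sort(old), np.sort(new))  # same frequencies, different order
```
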
%int1_7174 = torch.constant.int 1 - %5893 = torch.aten.slice.Tensor %5890, %int0_7172, %int0_7173, %5892, %int1_7174 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5893, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_7175 = torch.constant.int 1 - %int0_7176 = torch.constant.int 0 - %int9223372036854775807_7177 = torch.constant.int 9223372036854775807 - %int1_7178 = torch.constant.int 1 - %5894 = torch.aten.slice.Tensor %5893, %int1_7175, %int0_7176, %int9223372036854775807_7177, %int1_7178 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5894, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_7179 = torch.constant.int 1 - %int0_7180 = torch.constant.int 0 - %int9223372036854775807_7181 = torch.constant.int 9223372036854775807 - %int1_7182 = torch.constant.int 1 - %5895 = torch.aten.slice.Tensor %5894, %int1_7179, %int0_7180, %int9223372036854775807_7181, %int1_7182 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %5895, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_7183 = torch.constant.int 0 - %5896 = torch.aten.unsqueeze %5895, %int0_7183 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5896, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_7184 = torch.constant.int 1 - %int0_7185 = torch.constant.int 0 - %int9223372036854775807_7186 = torch.constant.int 9223372036854775807 - %int1_7187 = torch.constant.int 1 - %5897 = torch.aten.slice.Tensor %5896, %int1_7184, %int0_7185, %int9223372036854775807_7186, %int1_7187 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5897, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_7188 = torch.constant.int 2 - %int0_7189 = torch.constant.int 0 - %int9223372036854775807_7190 = torch.constant.int 9223372036854775807 - %int1_7191 = torch.constant.int 1 - %5898 = torch.aten.slice.Tensor %5897, %int2_7188, %int0_7189, %int9223372036854775807_7190, %int1_7191 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %5898, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_7192 = torch.constant.int 4 - %int1_7193 = torch.constant.int 1 - %int1_7194 = torch.constant.int 1 - %5899 = torch.prim.ListConstruct %int4_7192, %int1_7193, %int1_7194 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5900 = torch.aten.repeat %5898, %5899 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %5900, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_7195 = torch.constant.int 6 - %5901 = torch.prims.convert_element_type %5848, %int6_7195 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %5901, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %5902 = torch_c.to_builtin_tensor %5901 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %5903 = torch_c.to_builtin_tensor 
%5900 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %5904 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%5902, %5903) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %5905 = torch_c.from_builtin_tensor %5904 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %5905, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_7196 = torch.constant.int 5 - %5906 = torch.prims.convert_element_type %5905, %int5_7196 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5906, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_7197 = torch.constant.int 64 - %5907 = torch.aten.mul.Scalar %arg2, %int64_7197 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5907, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int56 = torch.constant.int 56 - %int1_7198 = torch.constant.int 1 - %5908 = torch.aten.add.Scalar %5907, %int56, %int1_7198 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5908, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %5885 = torch.aten.mul.Scalar %5884, %int2_7162 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5885, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_7163 = torch.constant.int 1 + %int1_7164 = torch.constant.int 1 + %5886 = torch.aten.add.Scalar %5885, %int1_7163, %int1_7164 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %5886, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %5887 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %5888 = torch.aten.view %5886, %5887 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %5888, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_7165 = torch.constant.int 4 + %int32_7166 = torch.constant.int 32 + %int8_7167 = torch.constant.int 8 + %int128_7168 = torch.constant.int 128 + %5889 = torch.prim.ListConstruct %int4_7165, %296, %int32_7166, %int8_7167, %int128_7168 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5890 = torch.aten.view %5730, %5889 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5890, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_7169 = torch.constant.int 32 + %int8_7170 = torch.constant.int 8 + %int128_7171 = torch.constant.int 128 + %5891 = torch.prim.ListConstruct %504, %int32_7169, %int8_7170, %int128_7171 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5892 = torch.aten.view %5890, %5891 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %5892, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_7172 = torch.constant.int 1 + %int2_7173 = torch.constant.int 2 + %5893 = torch.aten.transpose.int %5892, %int1_7172, %int2_7173 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5893, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + 
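
The slot arithmetic just above mirrors the key write for values: the same paged view, but add.Scalar +1 (%5886) instead of +0 picks the odd sub-slot. A worked instance of the indexing, with `kv_slot` a hypothetical helper name:

```python
def kv_slot(page_id, block_idx, kv):  # kv: 0 = key, 1 = value
    return (page_id * 32 + block_idx) * 2 + kv

assert kv_slot(3, 20, 0) == 232  # keys of block 20 on page 3
assert kv_slot(3, 20, 1) == 233  # values live in the adjacent odd slot
```

The removed index %5908 (page·64 + 56) appears to be the same scheme under the old seq-major [?,32,2,32,8,128] layout, decoding to block 28 there; the diff simply does not align identical blocks of the old and new files.
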
%int5_7174 = torch.constant.int 5 + %5894 = torch.prims.convert_element_type %5893, %int5_7174 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5894, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %5895 = torch.prim.ListConstruct %5888 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_7175 = torch.constant.bool false + %5896 = torch.aten.index_put %5882, %5895, %5894, %false_7175 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %5896, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_7176 = torch.constant.int 32 + %int2_7177 = torch.constant.int 2 + %int8_7178 = torch.constant.int 8 + %int32_7179 = torch.constant.int 32 + %int128_7180 = torch.constant.int 128 + %5897 = torch.prim.ListConstruct %297, %int32_7176, %int2_7177, %int8_7178, %int32_7179, %int128_7180 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5898 = torch.aten.view %5896, %5897 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5898, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7181 = torch.constant.int 2097152 + %5899 = torch.prim.ListConstruct %297, %int2097152_7181 : (!torch.int, !torch.int) -> !torch.list + %5900 = torch.aten.view %5898, %5899 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5900, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_7182 = torch.constant.int -2 + %5901 = torch.aten.unsqueeze %5856, %int-2_7182 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5901, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_7183 = torch.constant.int 4 + %int8_7184 = torch.constant.int 8 + %int4_7185 = torch.constant.int 4 + %int128_7186 = torch.constant.int 128 + %5902 = torch.prim.ListConstruct %int4_7183, %298, %int8_7184, %int4_7185, %int128_7186 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_7187 = torch.constant.bool false + %5903 = torch.aten.expand %5901, %5902, %false_7187 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5903, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_7188 = torch.constant.int 0 + %5904 = torch.aten.clone %5903, %int0_7188 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5904, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_7189 = torch.constant.int 4 + %int32_7190 = torch.constant.int 32 + %int128_7191 = torch.constant.int 128 + %5905 = torch.prim.ListConstruct %int4_7189, %298, %int32_7190, %int128_7191 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5906 = torch.aten._unsafe_view %5904, %5905 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5906, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_7192 = torch.constant.int -2 + %5907 
= torch.aten.unsqueeze %5730, %int-2_7192 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5907, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_7193 = torch.constant.int 4 + %int8_7194 = torch.constant.int 8 + %int4_7195 = torch.constant.int 4 + %int128_7196 = torch.constant.int 128 + %5908 = torch.prim.ListConstruct %int4_7193, %298, %int8_7194, %int4_7195, %int128_7196 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_7197 = torch.constant.bool false + %5909 = torch.aten.expand %5907, %5908, %false_7197 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5909, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_7198 = torch.constant.int 0 + %5910 = torch.aten.clone %5909, %int0_7198 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5910, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_7199 = torch.constant.int 4 %int32_7200 = torch.constant.int 32 - %int8_7201 = torch.constant.int 8 - %int128_7202 = torch.constant.int 128 - %5909 = torch.prim.ListConstruct %int4_7199, %398, %int32_7200, %int8_7201, %int128_7202 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5910 = torch.aten.view %5906, %5909 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5910, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_7203 = torch.constant.int 4 - %5911 = torch.aten.mul.int %int4_7203, %398 : !torch.int, !torch.int -> !torch.int - %int32_7204 = torch.constant.int 32 - %int8_7205 = torch.constant.int 8 - %int128_7206 = torch.constant.int 128 - %5912 = torch.prim.ListConstruct %5911, %int32_7204, %int8_7205, %int128_7206 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5913 = torch.aten.view %5910, %5912 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %5913, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_7207 = torch.constant.int 4 - %5914 = torch.aten.mul.int %int4_7207, %398 : !torch.int, !torch.int -> !torch.int - %5915 = torch.prim.ListConstruct %5914 : (!torch.int) -> !torch.list - %5916 = torch.aten.view %5908, %5915 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %5916, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_7208 = torch.constant.int 32 - %int2_7209 = torch.constant.int 2 - %int32_7210 = torch.constant.int 32 - %int8_7211 = torch.constant.int 8 - %int128_7212 = torch.constant.int 128 - %5917 = torch.prim.ListConstruct %389, %int32_7208, %int2_7209, %int32_7210, %int8_7211, %int128_7212 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5918 = torch.aten.view %5750, %5917 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5918, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7213 = torch.constant.int 32 - %5919 = torch.aten.mul.int %389, %int32_7213 : !torch.int, !torch.int -> !torch.int - %int2_7214 = 
- %5920 = torch.aten.mul.int %5919, %int2_7214 : !torch.int, !torch.int -> !torch.int
- %int32_7215 = torch.constant.int 32
- %int8_7216 = torch.constant.int 8
- %int128_7217 = torch.constant.int 128
- %5921 = torch.prim.ListConstruct %5920, %int32_7215, %int8_7216, %int128_7217 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5922 = torch.aten.view %5918, %5921 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %5922, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %5923 = torch.prim.ListConstruct %5916 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
- %false_7218 = torch.constant.bool false
- %5924 = torch.aten.index_put %5922, %5923, %5913, %false_7218 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %5924, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int32_7219 = torch.constant.int 32
- %int2_7220 = torch.constant.int 2
- %int32_7221 = torch.constant.int 32
- %int8_7222 = torch.constant.int 8
- %int128_7223 = torch.constant.int 128
- %5925 = torch.prim.ListConstruct %389, %int32_7219, %int2_7220, %int32_7221, %int8_7222, %int128_7223 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5926 = torch.aten.view %5924, %5925 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %5926, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int2097152_7224 = torch.constant.int 2097152
- %5927 = torch.prim.ListConstruct %389, %int2097152_7224 : (!torch.int, !torch.int) -> !torch.list<int>
- %5928 = torch.aten.view %5926, %5927 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
- torch.bind_symbolic_shape %5928, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
- %int32_7225 = torch.constant.int 32
- %int2_7226 = torch.constant.int 2
- %int32_7227 = torch.constant.int 32
- %int8_7228 = torch.constant.int 8
- %int128_7229 = torch.constant.int 128
- %5929 = torch.prim.ListConstruct %389, %int32_7225, %int2_7226, %int32_7227, %int8_7228, %int128_7229 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5930 = torch.aten.view %5928, %5929 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %5930, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int32_7230 = torch.constant.int 32
- %int8_7231 = torch.constant.int 8
- %int128_7232 = torch.constant.int 128
- %5931 = torch.prim.ListConstruct %5920, %int32_7230, %int8_7231, %int128_7232 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5932 = torch.aten.view %5930, %5931 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %5932, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int4_7233 = torch.constant.int 4
- %int32_7234 = torch.constant.int 32
- %int8_7235 = torch.constant.int 8
- %int128_7236 = torch.constant.int 128
- %5933 = torch.prim.ListConstruct %int4_7233, %398, %int32_7234, %int8_7235, %int128_7236 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5934 = torch.aten.view %5850, %5933 : !torch.vtensor<[4,?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %5934, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int4_7237 = torch.constant.int 4
- %5935 = torch.aten.mul.int %int4_7237, %398 : !torch.int, !torch.int -> !torch.int
- %int32_7238 = torch.constant.int 32
- %int8_7239 = torch.constant.int 8
- %int128_7240 = torch.constant.int 128
- %5936 = torch.prim.ListConstruct %5935, %int32_7238, %int8_7239, %int128_7240 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5937 = torch.aten.view %5934, %5936 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %5937, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int1_7241 = torch.constant.int 1
- %int1_7242 = torch.constant.int 1
- %5938 = torch.aten.add.Scalar %5908, %int1_7241, %int1_7242 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %5938, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int4_7243 = torch.constant.int 4
- %5939 = torch.aten.mul.int %int4_7243, %398 : !torch.int, !torch.int -> !torch.int
- %5940 = torch.prim.ListConstruct %5939 : (!torch.int) -> !torch.list<int>
- %5941 = torch.aten.view %5938, %5940 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
- torch.bind_symbolic_shape %5941, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
- %5942 = torch.prim.ListConstruct %5941 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
- %false_7244 = torch.constant.bool false
- %5943 = torch.aten.index_put %5932, %5942, %5937, %false_7244 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16>
- torch.bind_symbolic_shape %5943, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16>
- %int32_7245 = torch.constant.int 32
- %int2_7246 = torch.constant.int 2
- %int32_7247 = torch.constant.int 32
- %int8_7248 = torch.constant.int 8
- %int128_7249 = torch.constant.int 128
- %5944 = torch.prim.ListConstruct %389, %int32_7245, %int2_7246, %int32_7247, %int8_7248, %int128_7249 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5945 = torch.aten.view %5943, %5944 : !torch.vtensor<[?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %5945, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int2097152_7250 = torch.constant.int 2097152
- %5946 = torch.prim.ListConstruct %389, %int2097152_7250 : (!torch.int, !torch.int) -> !torch.list<int>
- %5947 = torch.aten.view %5945, %5946 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
- torch.bind_symbolic_shape %5947, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
- %int-2_7251 = torch.constant.int -2
- %5948 = torch.aten.unsqueeze %5906, %int-2_7251 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
- torch.bind_symbolic_shape %5948, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
- %int4_7252 = torch.constant.int 4
- %int8_7253 = torch.constant.int 8
- %int4_7254 = torch.constant.int 4
- %int128_7255 = torch.constant.int 128
- %5949 = torch.prim.ListConstruct %int4_7252, %5891, %int8_7253, %int4_7254, %int128_7255 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %false_7256 = torch.constant.bool false
- %5950 = torch.aten.expand %5948, %5949, %false_7256 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %5950, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int0_7257 = torch.constant.int 0
- %5951 = torch.aten.clone %5950, %int0_7257 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %5951, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int4_7258 = torch.constant.int 4
- %int32_7259 = torch.constant.int 32
- %int128_7260 = torch.constant.int 128
- %5952 = torch.prim.ListConstruct %int4_7258, %5891, %int32_7259, %int128_7260 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5953 = torch.aten._unsafe_view %5951, %5952 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %5953, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int-2_7261 = torch.constant.int -2
- %5954 = torch.aten.unsqueeze %5850, %int-2_7261 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
- torch.bind_symbolic_shape %5954, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
- %int1_7262 = torch.constant.int 1
- %5955 = torch.aten.size.int %5844, %int1_7262 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int
+ %int128_7201 = torch.constant.int 128
+ %5911 = torch.prim.ListConstruct %int4_7199, %298, %int32_7200, %int128_7201 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5912 = torch.aten._unsafe_view %5910, %5911 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %5912, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int1_7202 = torch.constant.int 1
+ %int2_7203 = torch.constant.int 2
+ %5913 = torch.aten.transpose.int %5793, %int1_7202, %int2_7203 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %5913, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_7204 = torch.constant.int 1
+ %int2_7205 = torch.constant.int 2
+ %5914 = torch.aten.transpose.int %5906, %int1_7204, %int2_7205 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %5914, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_7206 = torch.constant.int 1
+ %int2_7207 = torch.constant.int 2
+ %5915 = torch.aten.transpose.int %5912, %int1_7206, %int2_7207 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %5915, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %float0.000000e00_7208 = torch.constant.float 0.000000e+00
+ %false_7209 = torch.constant.bool false
+ %none_7210 = torch.constant.none
+ %5916:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5913, %5914, %5915, %float0.000000e00_7208, %false_7209, %327, %none_7210) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>)
+ torch.bind_symbolic_shape %5916#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_7211 = torch.constant.int 1
+ %int2_7212 = torch.constant.int 2
+ %5917 = torch.aten.transpose.int %5916#0, %int1_7211, %int2_7212 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %5917, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int4_7213 = torch.constant.int 4
+ %int4096_7214 = torch.constant.int 4096
+ %5918 = torch.prim.ListConstruct %int4_7213, %298, %int4096_7214 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5919 = torch.aten.view %5917, %5918 : !torch.vtensor<[4,?,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %5919, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_7215 = torch.constant.int -2
+ %int-1_7216 = torch.constant.int -1
+ %5920 = torch.aten.transpose.int %186, %int-2_7215, %int-1_7216 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int5_7217 = torch.constant.int 5
+ %5921 = torch.prims.convert_element_type %5920, %int5_7217 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int4096_7218 = torch.constant.int 4096
+ %5922 = torch.prim.ListConstruct %342, %int4096_7218 : (!torch.int, !torch.int) -> !torch.list<int>
+ %5923 = torch.aten.view %5919, %5922 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %5923, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %5924 = torch.aten.mm %5923, %5921 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %5924, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %int4_7219 = torch.constant.int 4
+ %int4096_7220 = torch.constant.int 4096
+ %5925 = torch.prim.ListConstruct %int4_7219, %298, %int4096_7220 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5926 = torch.aten.view %5924, %5925 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %5926, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int1_7221 = torch.constant.int 1
+ %5927 = torch.aten.add.Tensor %5693, %5926, %int1_7221 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %5927, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int6_7222 = torch.constant.int 6
+ %5928 = torch.prims.convert_element_type %5927, %int6_7222 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %5928, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int2_7223 = torch.constant.int 2
+ %5929 = torch.aten.pow.Tensor_Scalar %5928, %int2_7223 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %5929, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int-1_7224 = torch.constant.int -1
+ %5930 = torch.prim.ListConstruct %int-1_7224 : (!torch.int) -> !torch.list<int>
+ %true_7225 = torch.constant.bool true
+ %none_7226 = torch.constant.none
+ %5931 = torch.aten.mean.dim %5929, %5930, %true_7225, %none_7226 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %5931, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %float9.999990e-06_7227 = torch.constant.float 9.9999997473787516E-6
+ %int1_7228 = torch.constant.int 1
+ %5932 = torch.aten.add.Scalar %5931, %float9.999990e-06_7227, %int1_7228 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %5932, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %5933 = torch.aten.rsqrt %5932 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %5933, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %5934 = torch.aten.mul.Tensor %5928, %5933 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %5934, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_7229 = torch.constant.int 5
+ %5935 = torch.prims.convert_element_type %5934, %int5_7229 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %5935, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %5936 = torch.aten.mul.Tensor %187, %5935 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %5936, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_7230 = torch.constant.int 5
+ %5937 = torch.prims.convert_element_type %5936, %int5_7230 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %5937, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_7231 = torch.constant.int -2
+ %int-1_7232 = torch.constant.int -1
+ %5938 = torch.aten.transpose.int %188, %int-2_7231, %int-1_7232 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int5_7233 = torch.constant.int 5
+ %5939 = torch.prims.convert_element_type %5938, %int5_7233 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int4096_7234 = torch.constant.int 4096
+ %5940 = torch.prim.ListConstruct %342, %int4096_7234 : (!torch.int, !torch.int) -> !torch.list<int>
+ %5941 = torch.aten.view %5937, %5940 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %5941, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %5942 = torch.aten.mm %5941, %5939 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
+ torch.bind_symbolic_shape %5942, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
+ %int4_7235 = torch.constant.int 4
+ %int14336_7236 = torch.constant.int 14336
+ %5943 = torch.prim.ListConstruct %int4_7235, %298, %int14336_7236 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5944 = torch.aten.view %5942, %5943 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
+ torch.bind_symbolic_shape %5944, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
+ %5945 = torch.aten.silu %5944 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
+ torch.bind_symbolic_shape %5945, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
+ %int-2_7237 = torch.constant.int -2
+ %int-1_7238 = torch.constant.int -1
+ %5946 = torch.aten.transpose.int %189, %int-2_7237, %int-1_7238 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int5_7239 = torch.constant.int 5
+ %5947 = torch.prims.convert_element_type %5946, %int5_7239 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int4096_7240 = torch.constant.int 4096
+ %5948 = torch.prim.ListConstruct %342, %int4096_7240 : (!torch.int, !torch.int) -> !torch.list<int>
+ %5949 = torch.aten.view %5937, %5948 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %5949, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %5950 = torch.aten.mm %5949, %5947 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
+ torch.bind_symbolic_shape %5950, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
+ %int4_7241 = torch.constant.int 4
+ %int14336_7242 = torch.constant.int 14336
+ %5951 = torch.prim.ListConstruct %int4_7241, %298, %int14336_7242 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5952 = torch.aten.view %5950, %5951 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
+ torch.bind_symbolic_shape %5952, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
+ %5953 = torch.aten.mul.Tensor %5945, %5952 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
+ torch.bind_symbolic_shape %5953, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
+ %int-2_7243 = torch.constant.int -2
+ %int-1_7244 = torch.constant.int -1
+ %5954 = torch.aten.transpose.int %190, %int-2_7243, %int-1_7244 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
+ %int5_7245 = torch.constant.int 5
+ %5955 = torch.prims.convert_element_type %5954, %int5_7245 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16>
+ %int14336_7246 = torch.constant.int 14336
+ %5956 = torch.prim.ListConstruct %342, %int14336_7246 : (!torch.int, !torch.int) -> !torch.list<int>
+ %5957 = torch.aten.view %5953, %5956 : !torch.vtensor<[4,?,14336],f16>, !torch.list<int> -> !torch.vtensor<[?,14336],f16>
+ torch.bind_symbolic_shape %5957, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
+ %5958 = torch.aten.mm %5957, %5955 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %5958, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %int4_7247 = torch.constant.int 4
+ %int4096_7248 = torch.constant.int 4096
+ %5959 = torch.prim.ListConstruct %int4_7247, %298, %int4096_7248 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5960 = torch.aten.view %5958, %5959 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %5960, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int1_7249 = torch.constant.int 1
+ %5961 = torch.aten.add.Tensor %5927, %5960, %int1_7249 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %5961, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int6_7250 = torch.constant.int 6
+ %5962 = torch.prims.convert_element_type %5961, %int6_7250 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %5962, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int2_7251 = torch.constant.int 2
+ %5963 = torch.aten.pow.Tensor_Scalar %5962, %int2_7251 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %5963, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int-1_7252 = torch.constant.int -1
+ %5964 = torch.prim.ListConstruct %int-1_7252 : (!torch.int) -> !torch.list<int>
+ %true_7253 = torch.constant.bool true
+ %none_7254 = torch.constant.none
+ %5965 = torch.aten.mean.dim %5963, %5964, %true_7253, %none_7254 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %5965, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %float9.999990e-06_7255 = torch.constant.float 9.9999997473787516E-6
+ %int1_7256 = torch.constant.int 1
+ %5966 = torch.aten.add.Scalar %5965, %float9.999990e-06_7255, %int1_7256 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %5966, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %5967 = torch.aten.rsqrt %5966 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %5967, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %5968 = torch.aten.mul.Tensor %5962, %5967 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %5968, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_7257 = torch.constant.int 5
+ %5969 = torch.prims.convert_element_type %5968, %int5_7257 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %5969, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %5970 = torch.aten.mul.Tensor %191, %5969 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %5970, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_7258 = torch.constant.int 5
+ %5971 = torch.prims.convert_element_type %5970, %int5_7258 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %5971, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_7259 = torch.constant.int -2
+ %int-1_7260 = torch.constant.int -1
+ %5972 = torch.aten.transpose.int %192, %int-2_7259, %int-1_7260 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int5_7261 = torch.constant.int 5
+ %5973 = torch.prims.convert_element_type %5972, %int5_7261 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int4096_7262 = torch.constant.int 4096
+ %5974 = torch.prim.ListConstruct %342, %int4096_7262 : (!torch.int, !torch.int) -> !torch.list<int>
+ %5975 = torch.aten.view %5971, %5974 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %5975, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %5976 = torch.aten.mm %5975, %5973 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %5976, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
 %int4_7263 = torch.constant.int 4
- %int8_7264 = torch.constant.int 8
- %int4_7265 = torch.constant.int 4
- %int128_7266 = torch.constant.int 128
- %5956 = torch.prim.ListConstruct %int4_7263, %5955, %int8_7264, %int4_7265, %int128_7266 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %false_7267 = torch.constant.bool false
- %5957 = torch.aten.expand %5954, %5956, %false_7267 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %5957, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int0_7268 = torch.constant.int 0
- %5958 = torch.aten.clone %5957, %int0_7268 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %5958, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4096_7264 = torch.constant.int 4096
+ %5977 = torch.prim.ListConstruct %int4_7263, %298, %int4096_7264 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5978 = torch.aten.view %5976, %5977 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %5978, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_7265 = torch.constant.int -2
+ %int-1_7266 = torch.constant.int -1
+ %5979 = torch.aten.transpose.int %193, %int-2_7265, %int-1_7266 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int5_7267 = torch.constant.int 5
+ %5980 = torch.prims.convert_element_type %5979, %int5_7267 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int4096_7268 = torch.constant.int 4096
+ %5981 = torch.prim.ListConstruct %342, %int4096_7268 : (!torch.int, !torch.int) -> !torch.list<int>
+ %5982 = torch.aten.view %5971, %5981 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %5982, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %5983 = torch.aten.mm %5982, %5980 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
+ torch.bind_symbolic_shape %5983, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
 %int4_7269 = torch.constant.int 4
- %int32_7270 = torch.constant.int 32
- %int128_7271 = torch.constant.int 128
- %5959 = torch.prim.ListConstruct %int4_7269, %5955, %int32_7270, %int128_7271 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5960 = torch.aten._unsafe_view %5958, %5959 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %5960, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int1_7272 = torch.constant.int 1
- %int2_7273 = torch.constant.int 2
- %5961 = torch.aten.transpose.int %5878, %int1_7272, %int2_7273 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %5961, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_7274 = torch.constant.int 1
- %int2_7275 = torch.constant.int 2
- %5962 = torch.aten.transpose.int %5953, %int1_7274, %int2_7275 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %5962, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_7276 = torch.constant.int 1
- %int2_7277 = torch.constant.int 2
- %5963 = torch.aten.transpose.int %5960, %int1_7276, %int2_7277 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %5963, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %float0.000000e00_7278 = torch.constant.float 0.000000e+00
- %true_7279 = torch.constant.bool true
- %none_7280 = torch.constant.none
- %none_7281 = torch.constant.none
- %5964:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5961, %5962, %5963, %float0.000000e00_7278, %true_7279, %none_7280, %none_7281) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>)
- torch.bind_symbolic_shape %5964#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_7282 = torch.constant.int 1
- %int2_7283 = torch.constant.int 2
- %5965 = torch.aten.transpose.int %5964#0, %int1_7282, %int2_7283 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %5965, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int4_7284 = torch.constant.int 4
- %int4096_7285 = torch.constant.int 4096
- %5966 = torch.prim.ListConstruct %int4_7284, %5863, %int4096_7285 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5967 = torch.aten.view %5965, %5966 : !torch.vtensor<[4,?,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %5967, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_7286 = torch.constant.int -2
- %int-1_7287 = torch.constant.int -1
- %5968 = torch.aten.transpose.int %257, %int-2_7286, %int-1_7287 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_7288 = torch.constant.int 4
- %5969 = torch.aten.mul.int %int4_7288, %5863 : !torch.int, !torch.int -> !torch.int
- %int4096_7289 = torch.constant.int 4096
- %5970 = torch.prim.ListConstruct %5969, %int4096_7289 : (!torch.int, !torch.int) -> !torch.list<int>
- %5971 = torch.aten.view %5967, %5970 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %5971, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %5972 = torch.aten.mm %5971, %5968 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %5972, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_7290 = torch.constant.int 4
- %int4096_7291 = torch.constant.int 4096
- %5973 = torch.prim.ListConstruct %int4_7290, %5863, %int4096_7291 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5974 = torch.aten.view %5972, %5973 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %5974, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int1_7292 = torch.constant.int 1
- %5975 = torch.aten.add.Tensor %5813, %5974, %int1_7292 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %5975, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int6_7293 = torch.constant.int 6
- %5976 = torch.prims.convert_element_type %5975, %int6_7293 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %5976, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int2_7294 = torch.constant.int 2
- %5977 = torch.aten.pow.Tensor_Scalar %5976, %int2_7294 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %5977, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int-1_7295 = torch.constant.int -1
- %5978 = torch.prim.ListConstruct %int-1_7295 : (!torch.int) -> !torch.list<int>
- %true_7296 = torch.constant.bool true
- %none_7297 = torch.constant.none
- %5979 = torch.aten.mean.dim %5977, %5978, %true_7296, %none_7297 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %5979, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %float9.999990e-06_7298 = torch.constant.float 9.9999997473787516E-6
- %int1_7299 = torch.constant.int 1
- %5980 = torch.aten.add.Scalar %5979, %float9.999990e-06_7298, %int1_7299 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %5980, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %5981 = torch.aten.rsqrt %5980 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %5981, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %5982 = torch.aten.mul.Tensor %5976, %5981 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %5982, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_7300 = torch.constant.int 5
- %5983 = torch.prims.convert_element_type %5982, %int5_7300 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %5983, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %5984 = torch.aten.mul.Tensor %258, %5983 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %5984, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_7301 = torch.constant.int 5
- %5985 = torch.prims.convert_element_type %5984, %int5_7301 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %5985, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_7302 = torch.constant.int -2
- %int-1_7303 = torch.constant.int -1
- %5986 = torch.aten.transpose.int %259, %int-2_7302, %int-1_7303 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_7304 = torch.constant.int 4
- %5987 = torch.aten.mul.int %int4_7304, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_7305 = torch.constant.int 4096
- %5988 = torch.prim.ListConstruct %5987, %int4096_7305 : (!torch.int, !torch.int) -> !torch.list<int>
- %5989 = torch.aten.view %5985, %5988 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %5989, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %5990 = torch.aten.mm %5989, %5986 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %5990, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %int4_7306 = torch.constant.int 4
- %int14336_7307 = torch.constant.int 14336
- %5991 = torch.prim.ListConstruct %int4_7306, %306, %int14336_7307 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5992 = torch.aten.view %5990, %5991 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %5992, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %5993 = torch.aten.silu %5992 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %5993, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %int-2_7308 = torch.constant.int -2
- %int-1_7309 = torch.constant.int -1
- %5994 = torch.aten.transpose.int %260, %int-2_7308, %int-1_7309 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_7310 = torch.constant.int 4
- %5995 = torch.aten.mul.int %int4_7310, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_7311 = torch.constant.int 4096
- %5996 = torch.prim.ListConstruct %5995, %int4096_7311 : (!torch.int, !torch.int) -> !torch.list<int>
- %5997 = torch.aten.view %5985, %5996 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %5997, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %5998 = torch.aten.mm %5997, %5994 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %5998, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %int4_7312 = torch.constant.int 4
- %int14336_7313 = torch.constant.int 14336
- %5999 = torch.prim.ListConstruct %int4_7312, %306, %int14336_7313 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6000 = torch.aten.view %5998, %5999 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %6000, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %6001 = torch.aten.mul.Tensor %5993, %6000 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16>
- torch.bind_symbolic_shape %6001, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16>
- %int-2_7314 = torch.constant.int -2
+ %int1024_7270 = torch.constant.int 1024
+ %5984 = torch.prim.ListConstruct %int4_7269, %298, %int1024_7270 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5985 = torch.aten.view %5983, %5984 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
+ torch.bind_symbolic_shape %5985, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
+ %int-2_7271 = torch.constant.int -2
+ %int-1_7272 = torch.constant.int -1
+ %5986 = torch.aten.transpose.int %194, %int-2_7271, %int-1_7272 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int5_7273 = torch.constant.int 5
+ %5987 = torch.prims.convert_element_type %5986, %int5_7273 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int4096_7274 = torch.constant.int 4096
+ %5988 = torch.prim.ListConstruct %342, %int4096_7274 : (!torch.int, !torch.int) -> !torch.list<int>
+ %5989 = torch.aten.view %5971, %5988 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %5989, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %5990 = torch.aten.mm %5989, %5987 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
+ torch.bind_symbolic_shape %5990, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
+ %int4_7275 = torch.constant.int 4
+ %int1024_7276 = torch.constant.int 1024
+ %5991 = torch.prim.ListConstruct %int4_7275, %298, %int1024_7276 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5992 = torch.aten.view %5990, %5991 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
+ torch.bind_symbolic_shape %5992, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
+ %int4_7277 = torch.constant.int 4
+ %int32_7278 = torch.constant.int 32
+ %int128_7279 = torch.constant.int 128
+ %5993 = torch.prim.ListConstruct %int4_7277, %298, %int32_7278, %int128_7279 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5994 = torch.aten.view %5978, %5993 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %5994, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int4_7280 = torch.constant.int 4
+ %int8_7281 = torch.constant.int 8
+ %int128_7282 = torch.constant.int 128
+ %5995 = torch.prim.ListConstruct %int4_7280, %298, %int8_7281, %int128_7282 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5996 = torch.aten.view %5985, %5995 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %5996, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int4_7283 = torch.constant.int 4
+ %int8_7284 = torch.constant.int 8
+ %int128_7285 = torch.constant.int 128
+ %5997 = torch.prim.ListConstruct %int4_7283, %298, %int8_7284, %int128_7285 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5998 = torch.aten.view %5992, %5997 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %5998, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int131072_7286 = torch.constant.int 131072
+ %none_7287 = torch.constant.none
+ %none_7288 = torch.constant.none
+ %cpu_7289 = torch.constant.device "cpu"
+ %false_7290 = torch.constant.bool false
+ %5999 = torch.aten.arange %int131072_7286, %none_7287, %none_7288, %cpu_7289, %false_7290 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
+ %int0_7291 = torch.constant.int 0
+ %int128_7292 = torch.constant.int 128
+ %int2_7293 = torch.constant.int 2
+ %int4_7294 = torch.constant.int 4
+ %none_7295 = torch.constant.none
+ %cpu_7296 = torch.constant.device "cpu"
+ %false_7297 = torch.constant.bool false
+ %6000 = torch.aten.arange.start_step %int0_7291, %int128_7292, %int2_7293, %int4_7294, %none_7295, %cpu_7296, %false_7297 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64>
+ %int6_7298 = torch.constant.int 6
+ %6001 = torch.prims.convert_element_type %6000, %int6_7298 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32>
+ %int128_7299 = torch.constant.int 128
+ %6002 = torch.aten.div.Scalar %6001, %int128_7299 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float5.000000e05_7300 = torch.constant.float 5.000000e+05
+ %6003 = torch.aten.pow.Scalar %float5.000000e05_7300, %6002 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %6004 = torch.aten.reciprocal %6003 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float1.000000e00_7301 = torch.constant.float 1.000000e+00
+ %6005 = torch.aten.mul.Scalar %6004, %float1.000000e00_7301 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %6006 = torch.aten.reciprocal %6005 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float6.283190e00_7302 = torch.constant.float 6.2831853071795862
+ %6007 = torch.aten.mul.Scalar %6006, %float6.283190e00_7302 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %float8.192000e03_7303 = torch.constant.float 8.192000e+03
+ %6008 = torch.aten.gt.Scalar %6007, %float8.192000e03_7303 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %int8_7304 = torch.constant.int 8
+ %6009 = torch.aten.div.Scalar %6005, %int8_7304 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %6010 = torch.aten.where.self %6008, %6009, %6005 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %6011 = torch.aten.reciprocal %6007 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8192_7305 = torch.constant.int 8192
+ %6012 = torch.aten.mul.Scalar %6011, %int8192_7305 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %int1_7306 = torch.constant.int 1
+ %int1_7307 = torch.constant.int 1
+ %6013 = torch.aten.sub.Scalar %6012, %int1_7306, %int1_7307 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %int3_7308 = torch.constant.int 3
+ %6014 = torch.aten.div.Scalar %6013, %int3_7308 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %int1_7309 = torch.constant.int 1
+ %int1_7310 = torch.constant.int 1
+ %6015 = torch.aten.rsub.Scalar %6014, %int1_7309, %int1_7310 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %6016 = torch.aten.mul.Tensor %6015, %6010 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8_7311 = torch.constant.int 8
+ %6017 = torch.aten.div.Scalar %6016, %int8_7311 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %6018 = torch.aten.mul.Tensor %6014, %6010 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int1_7312 = torch.constant.int 1
+ %6019 = torch.aten.add.Tensor %6017, %6018, %int1_7312 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float2.048000e03_7313 = torch.constant.float 2.048000e+03
+ %6020 = torch.aten.lt.Scalar %6007, %float2.048000e03_7313 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %6021 = torch.aten.bitwise_not %6020 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %float8.192000e03_7314 = torch.constant.float 8.192000e+03
+ %6022 = torch.aten.gt.Scalar %6007, %float8.192000e03_7314 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %6023 = torch.aten.bitwise_not %6022 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %6024 = torch.aten.mul.Tensor %6021, %6023 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %6025 = torch.aten.where.self %6024, %6019, %6010 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %6026 = torch.prim.ListConstruct %6025, %6025 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list<vtensor>
 %int-1_7315 = torch.constant.int -1
- %6002 = torch.aten.transpose.int %261, %int-2_7314, %int-1_7315 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
- %int1_7316 = torch.constant.int 1
- %6003 = torch.aten.size.int %5992, %int1_7316 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int
- %int4_7317 = torch.constant.int 4
- %6004 = torch.aten.mul.int %int4_7317, %6003 : !torch.int, !torch.int -> !torch.int
- %int14336_7318 = torch.constant.int 14336
- %6005 = torch.prim.ListConstruct %6004, %int14336_7318 : (!torch.int, !torch.int) -> !torch.list<int>
- %6006 = torch.aten.view %6001, %6005 : !torch.vtensor<[4,?,14336],f16>, !torch.list<int> -> !torch.vtensor<[?,14336],f16>
- torch.bind_symbolic_shape %6006, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16>
- %6007 = torch.aten.mm %6006, %6002 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %6007, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_7319 = torch.constant.int 4
- %int4096_7320 = torch.constant.int 4096
- %6008 = torch.prim.ListConstruct %int4_7319, %6003, %int4096_7320 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6009 = torch.aten.view %6007, %6008 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %6009, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int1_7321 = torch.constant.int 1
- %6010 = torch.aten.add.Tensor %5975, %6009, %int1_7321 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %6010, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int6_7322 = torch.constant.int 6
- %6011 = torch.prims.convert_element_type %6010, %int6_7322 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %6011, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int2_7323 = torch.constant.int 2
- %6012 = torch.aten.pow.Tensor_Scalar %6011, %int2_7323 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %6012, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int-1_7324 = torch.constant.int -1
- %6013 = torch.prim.ListConstruct %int-1_7324 : (!torch.int) -> !torch.list<int>
- %true_7325 = torch.constant.bool true
- %none_7326 = torch.constant.none
- %6014 = torch.aten.mean.dim %6012, %6013, %true_7325, %none_7326 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %6014, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %float9.999990e-06_7327 = torch.constant.float 9.9999997473787516E-6
- %int1_7328 = torch.constant.int 1
- %6015 = torch.aten.add.Scalar %6014, %float9.999990e-06_7327, %int1_7328 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %6015, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %6016 = torch.aten.rsqrt %6015 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
- torch.bind_symbolic_shape %6016, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
- %6017 = torch.aten.mul.Tensor %6011, %6016 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %6017, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_7329 = torch.constant.int 5
- %6018 = torch.prims.convert_element_type %6017, %int5_7329 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %6018, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %6019 = torch.aten.mul.Tensor %262, %6018 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
- torch.bind_symbolic_shape %6019, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
- %int5_7330 = torch.constant.int 5
- %6020 = torch.prims.convert_element_type %6019, %int5_7330 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %6020, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_7331 = torch.constant.int -2
- %int-1_7332 = torch.constant.int -1
- %6021 = torch.aten.transpose.int %263, %int-2_7331, %int-1_7332 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_7333 = torch.constant.int 4
- %6022 = torch.aten.mul.int %int4_7333, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_7334 = torch.constant.int 4096
- %6023 = torch.prim.ListConstruct %6022, %int4096_7334 : (!torch.int, !torch.int) -> !torch.list<int>
- %6024 = torch.aten.view %6020, %6023 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %6024, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %6025 = torch.aten.mm %6024, %6021 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %6025, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_7335 = torch.constant.int 4
- %int4096_7336 = torch.constant.int 4096
- %6026 = torch.prim.ListConstruct %int4_7335, %306, %int4096_7336 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6027 = torch.aten.view %6025, %6026 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %6027, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_7337 = torch.constant.int -2
- %int-1_7338 = torch.constant.int -1
- %6028 = torch.aten.transpose.int %264, %int-2_7337, %int-1_7338 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_7339 = torch.constant.int 4
- %6029 = torch.aten.mul.int %int4_7339, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_7340 = torch.constant.int 4096
- %6030 = torch.prim.ListConstruct %6029, %int4096_7340 : (!torch.int, !torch.int) -> !torch.list<int>
- %6031 = torch.aten.view %6020, %6030 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %6031, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %6032 = torch.aten.mm %6031, %6028 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
- torch.bind_symbolic_shape %6032, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
- %int4_7341 = torch.constant.int 4
- %int1024_7342 = torch.constant.int 1024
- %6033 = torch.prim.ListConstruct %int4_7341, %306, %int1024_7342 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6034 = torch.aten.view %6032, %6033 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
- torch.bind_symbolic_shape %6034, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
- %int-2_7343 = torch.constant.int -2
- %int-1_7344 = torch.constant.int -1
- %6035 = torch.aten.transpose.int %265, %int-2_7343, %int-1_7344 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_7345 = torch.constant.int 4
- %6036 = torch.aten.mul.int %int4_7345, %306 : !torch.int, !torch.int -> !torch.int
- %int4096_7346 = torch.constant.int 4096
- %6037 = torch.prim.ListConstruct %6036, %int4096_7346 : (!torch.int, !torch.int) -> !torch.list<int>
- %6038 = torch.aten.view %6020, %6037 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %6038, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %6039 = torch.aten.mm %6038, %6035 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16>
- torch.bind_symbolic_shape %6039, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16>
+ %6027 = torch.aten.cat %6026, %int-1_7315 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[128],f32>
+ %int6_7316 = torch.constant.int 6
+ %6028 = torch.prims.convert_element_type %6027, %int6_7316 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
+ %int1_7317 = torch.constant.int 1
+ %6029 = torch.aten.unsqueeze %5999, %int1_7317 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
+ %int6_7318 = torch.constant.int 6
+ %6030 = torch.prims.convert_element_type %6029, %int6_7318 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32>
+ %int0_7319 = torch.constant.int 0
+ %6031 = torch.aten.unsqueeze %6028, %int0_7319 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+ %int6_7320 = torch.constant.int 6
+ %6032 = torch.prims.convert_element_type %6031, %int6_7320 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
+ %6033 = torch.aten.mul.Tensor %6030, %6032 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %6034 = torch.aten.cos %6033 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %int5_7321 = torch.constant.int 5
+ %6035 = torch.prims.convert_element_type %6034, %int5_7321 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
+ %6036 = torch.aten.sin %6033 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32>
+ %int5_7322 = torch.constant.int 5
+ %6037 = torch.prims.convert_element_type %6036, %int5_7322 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16>
+ %int0_7323 = torch.constant.int 0
+ %int0_7324 = torch.constant.int 0
+ %int1_7325 = torch.constant.int 1
+ %6038 = torch.aten.slice.Tensor %6035, %int0_7323, %int0_7324, %298, %int1_7325 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %6038, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int1_7326 = torch.constant.int 1
+ %int0_7327 = torch.constant.int 0
+ %int9223372036854775807_7328 = torch.constant.int 9223372036854775807
+ %int1_7329 = torch.constant.int 1
+ %6039 = torch.aten.slice.Tensor %6038, %int1_7326, %int0_7327, %int9223372036854775807_7328, %int1_7329 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %6039, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int0_7330 = torch.constant.int 0
+ %int0_7331 = torch.constant.int 0
+ %int1_7332 = torch.constant.int 1
+ %6040 = torch.aten.slice.Tensor %6037, %int0_7330, %int0_7331, %298, %int1_7332 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %6040, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int1_7333 = torch.constant.int 1
+ %int0_7334 = torch.constant.int 0
+ %int9223372036854775807_7335 = torch.constant.int 9223372036854775807
+ %int1_7336 = torch.constant.int 1
+ %6041 = torch.aten.slice.Tensor %6040, %int1_7333, %int0_7334, %int9223372036854775807_7335, %int1_7336 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %6041, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16>
+ %int0_7337 = torch.constant.int 0
+ %6042 = torch.aten.unsqueeze %6039, %int0_7337 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %6042, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int1_7338 = torch.constant.int 1
+ %int0_7339 = torch.constant.int 0
+ %int9223372036854775807_7340 = torch.constant.int 9223372036854775807
+ %int1_7341 = torch.constant.int 1
+ %6043 = torch.aten.slice.Tensor %6042, %int1_7338, %int0_7339, %int9223372036854775807_7340, %int1_7341 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %6043, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int2_7342 = torch.constant.int 2
+ %6044 = torch.aten.unsqueeze %6043, %int2_7342 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %6044, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_7343 = torch.constant.int 3
+ %int0_7344 = torch.constant.int 0
+ %int9223372036854775807_7345 = torch.constant.int 9223372036854775807
+ %int1_7346 = torch.constant.int 1
+ %6045 = torch.aten.slice.Tensor %6044, %int3_7343, %int0_7344, %int9223372036854775807_7345, %int1_7346 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %6045, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
 %int4_7347 = torch.constant.int 4
- %int1024_7348 = torch.constant.int 1024
- %6040 = torch.prim.ListConstruct %int4_7347, %306, %int1024_7348 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6041 = torch.aten.view %6039, %6040 : !torch.vtensor<[?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1024],f16>
- torch.bind_symbolic_shape %6041, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16>
- %int4_7349 = torch.constant.int 4
- %int32_7350 = torch.constant.int 32
- %int128_7351 = torch.constant.int 128
- %6042 = torch.prim.ListConstruct %int4_7349, %306, %int32_7350, %int128_7351 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6043 = torch.aten.view %6027, %6042 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %6043, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int4_7352 = torch.constant.int 4
- %int8_7353 = torch.constant.int 8
- %int128_7354 = torch.constant.int 128
- %6044 = torch.prim.ListConstruct %int4_7352, %306, %int8_7353, %int128_7354 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6045 = torch.aten.view %6034, %6044 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %6045, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int4_7355 = torch.constant.int 4
- %int8_7356 = torch.constant.int 8
- %int128_7357 = torch.constant.int 128
- %6046 = torch.prim.ListConstruct %int4_7355, %306, %int8_7356, %int128_7357 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6047 = torch.aten.view %6041, %6046 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %6047, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int131072_7358 = torch.constant.int 131072
- %none_7359 = torch.constant.none
- %none_7360 = torch.constant.none
- %cpu_7361 = torch.constant.device "cpu"
- %false_7362 = torch.constant.bool false
- %6048 = torch.aten.arange %int131072_7358, %none_7359, %none_7360, %cpu_7361, %false_7362 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
- %int0_7363 = torch.constant.int 0
- %int128_7364 = torch.constant.int 128
- %none_7365 = torch.constant.none
- %none_7366 = torch.constant.none
- %cpu_7367 = torch.constant.device "cpu"
- %false_7368 = torch.constant.bool false
- %6049 = torch.aten.arange.start %int0_7363, %int128_7364, %none_7365, %none_7366, %cpu_7367, %false_7368 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64>
- %int2_7369 = torch.constant.int 2
- %6050 = torch.aten.floor_divide.Scalar %6049, %int2_7369 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64>
- %int6_7370 = torch.constant.int 6
- %6051 = torch.prims.convert_element_type %6050, %int6_7370 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32>
- %int128_7371 = torch.constant.int 128
- %6052 = torch.aten.div.Scalar %6051, %int128_7371 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32>
- %float2.000000e00_7372 = torch.constant.float 2.000000e+00
- %6053 = torch.aten.mul.Scalar %6052, %float2.000000e00_7372 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
- %float5.000000e05_7373 = torch.constant.float 5.000000e+05
- %6054 = torch.aten.pow.Scalar %float5.000000e05_7373, %6053 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
- %6055 = torch.aten.reciprocal %6054 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32>
- %float1.000000e00_7374 = torch.constant.float 1.000000e+00
- %6056 = torch.aten.mul.Scalar %6055, %float1.000000e00_7374 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
- %int1_7375 = torch.constant.int 1
- %6057 = torch.aten.unsqueeze %6048, %int1_7375 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
- %int0_7376 = torch.constant.int 0
- %6058 = torch.aten.unsqueeze %6056, %int0_7376 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
- %6059 = torch.aten.mul.Tensor %6057, %6058 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
- %int1_7377 = torch.constant.int 1
- %6060 = torch.aten.size.int %6027, %int1_7377 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int
- %int0_7378 = torch.constant.int 0
- %6061 = torch.aten.add.int %int0_7378, %6060 : !torch.int, !torch.int -> !torch.int
- %int0_7379 = torch.constant.int 0
+ %int1_7348 = torch.constant.int 1
+ %int1_7349 = torch.constant.int 1
+ %int1_7350 = torch.constant.int 1
+ %6046 = torch.prim.ListConstruct %int4_7347, %int1_7348, %int1_7349, %int1_7350 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6047 = torch.aten.repeat %6045, %6046 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,1,128],f16>
+ torch.bind_symbolic_shape %6047, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16>
+ %int0_7351 = torch.constant.int 0
+ %6048 = torch.aten.unsqueeze %6041, %int0_7351 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %6048, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int1_7352 = torch.constant.int 1
+ %int0_7353 = torch.constant.int 0
+ %int9223372036854775807_7354 = torch.constant.int 9223372036854775807
+ %int1_7355 = torch.constant.int 1
+ %6049 = torch.aten.slice.Tensor %6048, %int1_7352, %int0_7353, %int9223372036854775807_7354, %int1_7355 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16>
+ torch.bind_symbolic_shape %6049, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16>
+ %int2_7356 = torch.constant.int 2
+ %6050 = torch.aten.unsqueeze %6049, %int2_7356 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %6050, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int3_7357 = torch.constant.int 3
+ %int0_7358 = torch.constant.int 0
+ %int9223372036854775807_7359 = torch.constant.int 9223372036854775807
+ %int1_7360 = torch.constant.int 1
+ %6051 = torch.aten.slice.Tensor %6050, %int3_7357, %int0_7358, %int9223372036854775807_7359, %int1_7360 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16>
+ torch.bind_symbolic_shape %6051, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16>
+ %int4_7361 = torch.constant.int 4
+ %int1_7362 = torch.constant.int 1
+ %int1_7363 = torch.constant.int 1
+ %int1_7364 = torch.constant.int 1
+ %6052 = torch.prim.ListConstruct %int4_7361, %int1_7362, %int1_7363, %int1_7364 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6053 = torch.aten.repeat %6051, %6052 : !torch.vtensor<[1,?,1,128],f16>, !torch.list<int>
-> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6053, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %6054 = torch.aten.mul.Tensor %5994, %6047 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6054, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_7365 = torch.constant.int 3 + %int0_7366 = torch.constant.int 0 + %int64_7367 = torch.constant.int 64 + %int1_7368 = torch.constant.int 1 + %6055 = torch.aten.slice.Tensor %5994, %int3_7365, %int0_7366, %int64_7367, %int1_7368 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %6055, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_7369 = torch.constant.int 3 + %int64_7370 = torch.constant.int 64 + %int9223372036854775807_7371 = torch.constant.int 9223372036854775807 + %int1_7372 = torch.constant.int 1 + %6056 = torch.aten.slice.Tensor %5994, %int3_7369, %int64_7370, %int9223372036854775807_7371, %int1_7372 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %6056, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %6057 = torch.aten.neg %6056 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %6057, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %6058 = torch.prim.ListConstruct %6057, %6055 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_7373 = torch.constant.int -1 + %6059 = torch.aten.cat %6058, %int-1_7373 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6059, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %6060 = torch.aten.mul.Tensor %6059, %6053 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6060, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_7374 = torch.constant.int 1 + %6061 = torch.aten.add.Tensor %6054, %6060, %int1_7374 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6061, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_7375 = torch.constant.int 131072 + %none_7376 = torch.constant.none + %none_7377 = torch.constant.none + %cpu_7378 = torch.constant.device "cpu" + %false_7379 = torch.constant.bool false + %6062 = torch.aten.arange %int131072_7375, %none_7376, %none_7377, %cpu_7378, %false_7379 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> %int0_7380 = torch.constant.int 0 - %int1_7381 = torch.constant.int 1 - %6062 = torch.aten.slice.Tensor %6059, %int0_7379, %int0_7380, %6061, %int1_7381 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6062, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_7382 = torch.constant.int 1 - %int0_7383 = torch.constant.int 0 - %int9223372036854775807_7384 = torch.constant.int 
9223372036854775807 - %int1_7385 = torch.constant.int 1 - %6063 = torch.aten.slice.Tensor %6062, %int1_7382, %int0_7383, %int9223372036854775807_7384, %int1_7385 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6063, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_7386 = torch.constant.int 1 - %int0_7387 = torch.constant.int 0 - %int9223372036854775807_7388 = torch.constant.int 9223372036854775807 - %int1_7389 = torch.constant.int 1 - %6064 = torch.aten.slice.Tensor %6063, %int1_7386, %int0_7387, %int9223372036854775807_7388, %int1_7389 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6064, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_7390 = torch.constant.int 0 - %6065 = torch.aten.unsqueeze %6064, %int0_7390 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6065, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_7391 = torch.constant.int 1 - %int0_7392 = torch.constant.int 0 - %int9223372036854775807_7393 = torch.constant.int 9223372036854775807 - %int1_7394 = torch.constant.int 1 - %6066 = torch.aten.slice.Tensor %6065, %int1_7391, %int0_7392, %int9223372036854775807_7393, %int1_7394 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6066, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_7395 = torch.constant.int 2 - %int0_7396 = torch.constant.int 0 - %int9223372036854775807_7397 = torch.constant.int 9223372036854775807 + %int128_7381 = torch.constant.int 128 + %int2_7382 = torch.constant.int 2 + %int4_7383 = torch.constant.int 4 + %none_7384 = torch.constant.none + %cpu_7385 = torch.constant.device "cpu" + %false_7386 = torch.constant.bool false + %6063 = torch.aten.arange.start_step %int0_7380, %int128_7381, %int2_7382, %int4_7383, %none_7384, %cpu_7385, %false_7386 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_7387 = torch.constant.int 6 + %6064 = torch.prims.convert_element_type %6063, %int6_7387 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_7388 = torch.constant.int 128 + %6065 = torch.aten.div.Scalar %6064, %int128_7388 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_7389 = torch.constant.float 5.000000e+05 + %6066 = torch.aten.pow.Scalar %float5.000000e05_7389, %6065 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6067 = torch.aten.reciprocal %6066 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_7390 = torch.constant.float 1.000000e+00 + %6068 = torch.aten.mul.Scalar %6067, %float1.000000e00_7390 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %6069 = torch.aten.reciprocal %6068 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_7391 = torch.constant.float 6.2831853071795862 + %6070 = torch.aten.mul.Scalar %6069, %float6.283190e00_7391 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_7392 = torch.constant.float 8.192000e+03 + %6071 = torch.aten.gt.Scalar %6070, %float8.192000e03_7392 : !torch.vtensor<[64],f32>, !torch.float -> 
!torch.vtensor<[64],i1> + %int8_7393 = torch.constant.int 8 + %6072 = torch.aten.div.Scalar %6068, %int8_7393 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %6073 = torch.aten.where.self %6071, %6072, %6068 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6074 = torch.aten.reciprocal %6070 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_7394 = torch.constant.int 8192 + %6075 = torch.aten.mul.Scalar %6074, %int8192_7394 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_7395 = torch.constant.int 1 + %int1_7396 = torch.constant.int 1 + %6076 = torch.aten.sub.Scalar %6075, %int1_7395, %int1_7396 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_7397 = torch.constant.int 3 + %6077 = torch.aten.div.Scalar %6076, %int3_7397 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %int1_7398 = torch.constant.int 1 - %6067 = torch.aten.slice.Tensor %6066, %int2_7395, %int0_7396, %int9223372036854775807_7397, %int1_7398 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6067, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_7399 = torch.constant.int 4 - %int1_7400 = torch.constant.int 1 + %int1_7399 = torch.constant.int 1 + %6078 = torch.aten.rsub.Scalar %6077, %int1_7398, %int1_7399 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %6079 = torch.aten.mul.Tensor %6078, %6073 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_7400 = torch.constant.int 8 + %6080 = torch.aten.div.Scalar %6079, %int8_7400 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %6081 = torch.aten.mul.Tensor %6077, %6073 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> %int1_7401 = torch.constant.int 1 - %6068 = torch.prim.ListConstruct %int4_7399, %int1_7400, %int1_7401 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6069 = torch.aten.repeat %6067, %6068 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %6069, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_7402 = torch.constant.int 6 - %6070 = torch.prims.convert_element_type %6043, %int6_7402 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %6070, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %6071 = torch_c.to_builtin_tensor %6070 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %6072 = torch_c.to_builtin_tensor %6069 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %6073 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%6071, %6072) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %6074 = torch_c.from_builtin_tensor %6073 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %6074, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_7403 = torch.constant.int 5 - %6075 = torch.prims.convert_element_type %6074, %int5_7403 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6075, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - 
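// NOTE: the `+` ops in this region (%6063..%6091) read as llama3.1-style RoPE
// frequency scaling: base frequencies 1/500000^(2j/128) (%6067), wavelength
// 2*pi/freq (%6070), then frequencies with wavelength > 8192 are divided by
// the scale factor 8 (%6072), those with wavelength < 2048 are kept, and the
// band in between is blended via s = (8192/wavelength - 1)/3 as
// (1-s)*freq/8 + s*freq (%6077..%6082) before the position-by-frequency
// cos/sin tables are built and cast to f16.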
%int131072_7404 = torch.constant.int 131072 - %none_7405 = torch.constant.none - %none_7406 = torch.constant.none - %cpu_7407 = torch.constant.device "cpu" - %false_7408 = torch.constant.bool false - %6076 = torch.aten.arange %int131072_7404, %none_7405, %none_7406, %cpu_7407, %false_7408 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_7409 = torch.constant.int 0 - %int128_7410 = torch.constant.int 128 - %none_7411 = torch.constant.none - %none_7412 = torch.constant.none - %cpu_7413 = torch.constant.device "cpu" - %false_7414 = torch.constant.bool false - %6077 = torch.aten.arange.start %int0_7409, %int128_7410, %none_7411, %none_7412, %cpu_7413, %false_7414 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_7415 = torch.constant.int 2 - %6078 = torch.aten.floor_divide.Scalar %6077, %int2_7415 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_7416 = torch.constant.int 6 - %6079 = torch.prims.convert_element_type %6078, %int6_7416 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_7417 = torch.constant.int 128 - %6080 = torch.aten.div.Scalar %6079, %int128_7417 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_7418 = torch.constant.float 2.000000e+00 - %6081 = torch.aten.mul.Scalar %6080, %float2.000000e00_7418 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_7419 = torch.constant.float 5.000000e+05 - %6082 = torch.aten.pow.Scalar %float5.000000e05_7419, %6081 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %6083 = torch.aten.reciprocal %6082 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_7420 = torch.constant.float 1.000000e+00 - %6084 = torch.aten.mul.Scalar %6083, %float1.000000e00_7420 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %6082 = torch.aten.add.Tensor %6080, %6081, %int1_7401 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_7402 = torch.constant.float 2.048000e+03 + %6083 = torch.aten.lt.Scalar %6070, %float2.048000e03_7402 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %6084 = torch.aten.bitwise_not %6083 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_7403 = torch.constant.float 8.192000e+03 + %6085 = torch.aten.gt.Scalar %6070, %float8.192000e03_7403 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %6086 = torch.aten.bitwise_not %6085 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6087 = torch.aten.mul.Tensor %6084, %6086 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6088 = torch.aten.where.self %6087, %6082, %6073 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6089 = torch.prim.ListConstruct %6088, %6088 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_7404 = torch.constant.int -1 + %6090 = torch.aten.cat %6089, %int-1_7404 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_7405 = torch.constant.int 6 + %6091 = torch.prims.convert_element_type %6090, %int6_7405 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_7406 = torch.constant.int 1 + %6092 = torch.aten.unsqueeze %6062, %int1_7406 : !torch.vtensor<[131072],si64>, !torch.int -> 
!torch.vtensor<[131072,1],si64> + %int6_7407 = torch.constant.int 6 + %6093 = torch.prims.convert_element_type %6092, %int6_7407 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_7408 = torch.constant.int 0 + %6094 = torch.aten.unsqueeze %6091, %int0_7408 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_7409 = torch.constant.int 6 + %6095 = torch.prims.convert_element_type %6094, %int6_7409 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %6096 = torch.aten.mul.Tensor %6093, %6095 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %6097 = torch.aten.cos %6096 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_7410 = torch.constant.int 5 + %6098 = torch.prims.convert_element_type %6097, %int5_7410 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %6099 = torch.aten.sin %6096 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_7411 = torch.constant.int 5 + %6100 = torch.prims.convert_element_type %6099, %int5_7411 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_7412 = torch.constant.int 0 + %int0_7413 = torch.constant.int 0 + %int1_7414 = torch.constant.int 1 + %6101 = torch.aten.slice.Tensor %6098, %int0_7412, %int0_7413, %298, %int1_7414 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6101, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_7415 = torch.constant.int 1 + %int0_7416 = torch.constant.int 0 + %int9223372036854775807_7417 = torch.constant.int 9223372036854775807 + %int1_7418 = torch.constant.int 1 + %6102 = torch.aten.slice.Tensor %6101, %int1_7415, %int0_7416, %int9223372036854775807_7417, %int1_7418 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6102, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_7419 = torch.constant.int 0 + %int0_7420 = torch.constant.int 0 %int1_7421 = torch.constant.int 1 - %6085 = torch.aten.unsqueeze %6076, %int1_7421 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_7422 = torch.constant.int 0 - %6086 = torch.aten.unsqueeze %6084, %int0_7422 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %6087 = torch.aten.mul.Tensor %6085, %6086 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_7423 = torch.constant.int 1 - %6088 = torch.aten.size.int %6034, %int1_7423 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_7424 = torch.constant.int 0 - %6089 = torch.aten.add.int %int0_7424, %6088 : !torch.int, !torch.int -> !torch.int - %int0_7425 = torch.constant.int 0 + %6103 = torch.aten.slice.Tensor %6100, %int0_7419, %int0_7420, %298, %int1_7421 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6103, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_7422 = torch.constant.int 1 + %int0_7423 = torch.constant.int 0 + %int9223372036854775807_7424 = torch.constant.int 9223372036854775807 + %int1_7425 = torch.constant.int 1 + %6104 = torch.aten.slice.Tensor %6103, %int1_7422, %int0_7423, 
%int9223372036854775807_7424, %int1_7425 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6104, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int0_7426 = torch.constant.int 0 + %6105 = torch.aten.unsqueeze %6102, %int0_7426 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6105, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int1_7427 = torch.constant.int 1 - %6090 = torch.aten.slice.Tensor %6087, %int0_7425, %int0_7426, %6089, %int1_7427 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6090, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_7428 = torch.constant.int 1 - %int0_7429 = torch.constant.int 0 - %int9223372036854775807_7430 = torch.constant.int 9223372036854775807 - %int1_7431 = torch.constant.int 1 - %6091 = torch.aten.slice.Tensor %6090, %int1_7428, %int0_7429, %int9223372036854775807_7430, %int1_7431 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6091, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_7432 = torch.constant.int 1 + %int0_7428 = torch.constant.int 0 + %int9223372036854775807_7429 = torch.constant.int 9223372036854775807 + %int1_7430 = torch.constant.int 1 + %6106 = torch.aten.slice.Tensor %6105, %int1_7427, %int0_7428, %int9223372036854775807_7429, %int1_7430 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6106, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_7431 = torch.constant.int 2 + %6107 = torch.aten.unsqueeze %6106, %int2_7431 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6107, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_7432 = torch.constant.int 3 %int0_7433 = torch.constant.int 0 %int9223372036854775807_7434 = torch.constant.int 9223372036854775807 %int1_7435 = torch.constant.int 1 - %6092 = torch.aten.slice.Tensor %6091, %int1_7432, %int0_7433, %int9223372036854775807_7434, %int1_7435 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6092, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_7436 = torch.constant.int 0 - %6093 = torch.aten.unsqueeze %6092, %int0_7436 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6093, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> + %6108 = torch.aten.slice.Tensor %6107, %int3_7432, %int0_7433, %int9223372036854775807_7434, %int1_7435 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6108, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_7436 = torch.constant.int 4 %int1_7437 = torch.constant.int 1 - %int0_7438 = torch.constant.int 0 - %int9223372036854775807_7439 = torch.constant.int 9223372036854775807 - %int1_7440 = torch.constant.int 1 - %6094 = torch.aten.slice.Tensor %6093, 
%int1_7437, %int0_7438, %int9223372036854775807_7439, %int1_7440 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6094, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_7441 = torch.constant.int 2 + %int1_7438 = torch.constant.int 1 + %int1_7439 = torch.constant.int 1 + %6109 = torch.prim.ListConstruct %int4_7436, %int1_7437, %int1_7438, %int1_7439 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6110 = torch.aten.repeat %6108, %6109 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6110, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_7440 = torch.constant.int 0 + %6111 = torch.aten.unsqueeze %6104, %int0_7440 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6111, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_7441 = torch.constant.int 1 %int0_7442 = torch.constant.int 0 %int9223372036854775807_7443 = torch.constant.int 9223372036854775807 %int1_7444 = torch.constant.int 1 - %6095 = torch.aten.slice.Tensor %6094, %int2_7441, %int0_7442, %int9223372036854775807_7443, %int1_7444 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6095, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_7445 = torch.constant.int 4 - %int1_7446 = torch.constant.int 1 - %int1_7447 = torch.constant.int 1 - %6096 = torch.prim.ListConstruct %int4_7445, %int1_7446, %int1_7447 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6097 = torch.aten.repeat %6095, %6096 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %6097, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_7448 = torch.constant.int 6 - %6098 = torch.prims.convert_element_type %6045, %int6_7448 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %6098, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %6099 = torch_c.to_builtin_tensor %6098 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %6100 = torch_c.to_builtin_tensor %6097 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %6101 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%6099, %6100) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %6102 = torch_c.from_builtin_tensor %6101 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %6102, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_7449 = torch.constant.int 5 - %6103 = torch.prims.convert_element_type %6102, %int5_7449 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6103, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_7450 = torch.constant.int 64 - %6104 = torch.aten.mul.Scalar %arg2, %int64_7450 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6104, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int58 = torch.constant.int 58 + %6112 = torch.aten.slice.Tensor %6111, %int1_7441, 
%int0_7442, %int9223372036854775807_7443, %int1_7444 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6112, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_7445 = torch.constant.int 2 + %6113 = torch.aten.unsqueeze %6112, %int2_7445 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6113, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_7446 = torch.constant.int 3 + %int0_7447 = torch.constant.int 0 + %int9223372036854775807_7448 = torch.constant.int 9223372036854775807 + %int1_7449 = torch.constant.int 1 + %6114 = torch.aten.slice.Tensor %6113, %int3_7446, %int0_7447, %int9223372036854775807_7448, %int1_7449 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6114, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_7450 = torch.constant.int 4 %int1_7451 = torch.constant.int 1 - %6105 = torch.aten.add.Scalar %6104, %int58, %int1_7451 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6105, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_7452 = torch.constant.int 4 - %int32_7453 = torch.constant.int 32 - %int8_7454 = torch.constant.int 8 - %int128_7455 = torch.constant.int 128 - %6106 = torch.prim.ListConstruct %int4_7452, %398, %int32_7453, %int8_7454, %int128_7455 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6107 = torch.aten.view %6103, %6106 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6107, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_7456 = torch.constant.int 4 - %6108 = torch.aten.mul.int %int4_7456, %398 : !torch.int, !torch.int -> !torch.int - %int32_7457 = torch.constant.int 32 - %int8_7458 = torch.constant.int 8 - %int128_7459 = torch.constant.int 128 - %6109 = torch.prim.ListConstruct %6108, %int32_7457, %int8_7458, %int128_7459 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6110 = torch.aten.view %6107, %6109 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6110, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_7460 = torch.constant.int 4 - %6111 = torch.aten.mul.int %int4_7460, %398 : !torch.int, !torch.int -> !torch.int - %6112 = torch.prim.ListConstruct %6111 : (!torch.int) -> !torch.list - %6113 = torch.aten.view %6105, %6112 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %6113, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_7461 = torch.constant.int 32 - %int2_7462 = torch.constant.int 2 - %int32_7463 = torch.constant.int 32 - %int8_7464 = torch.constant.int 8 - %int128_7465 = torch.constant.int 128 - %6114 = torch.prim.ListConstruct %389, %int32_7461, %int2_7462, %int32_7463, %int8_7464, %int128_7465 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6115 = torch.aten.view %5947, %6114 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6115, [%293], 
affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7466 = torch.constant.int 32 - %6116 = torch.aten.mul.int %389, %int32_7466 : !torch.int, !torch.int -> !torch.int - %int2_7467 = torch.constant.int 2 - %6117 = torch.aten.mul.int %6116, %int2_7467 : !torch.int, !torch.int -> !torch.int - %int32_7468 = torch.constant.int 32 - %int8_7469 = torch.constant.int 8 - %int128_7470 = torch.constant.int 128 - %6118 = torch.prim.ListConstruct %6117, %int32_7468, %int8_7469, %int128_7470 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6119 = torch.aten.view %6115, %6118 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6119, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %6120 = torch.prim.ListConstruct %6113 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_7471 = torch.constant.bool false - %6121 = torch.aten.index_put %6119, %6120, %6110, %false_7471 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6121, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_7472 = torch.constant.int 32 - %int2_7473 = torch.constant.int 2 - %int32_7474 = torch.constant.int 32 - %int8_7475 = torch.constant.int 8 - %int128_7476 = torch.constant.int 128 - %6122 = torch.prim.ListConstruct %389, %int32_7472, %int2_7473, %int32_7474, %int8_7475, %int128_7476 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6123 = torch.aten.view %6121, %6122 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6123, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_7477 = torch.constant.int 2097152 - %6124 = torch.prim.ListConstruct %389, %int2097152_7477 : (!torch.int, !torch.int) -> !torch.list - %6125 = torch.aten.view %6123, %6124 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6125, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_7478 = torch.constant.int 32 - %int2_7479 = torch.constant.int 2 - %int32_7480 = torch.constant.int 32 + %int1_7452 = torch.constant.int 1 + %int1_7453 = torch.constant.int 1 + %6115 = torch.prim.ListConstruct %int4_7450, %int1_7451, %int1_7452, %int1_7453 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6116 = torch.aten.repeat %6114, %6115 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6116, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %6117 = torch.aten.mul.Tensor %5996, %6110 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6117, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_7454 = torch.constant.int 3 + %int0_7455 = torch.constant.int 0 + %int64_7456 = torch.constant.int 64 + %int1_7457 = torch.constant.int 1 + %6118 = torch.aten.slice.Tensor %5996, %int3_7454, %int0_7455, %int64_7456, %int1_7457 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %6118, [%294], 
affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_7458 = torch.constant.int 3 + %int64_7459 = torch.constant.int 64 + %int9223372036854775807_7460 = torch.constant.int 9223372036854775807 + %int1_7461 = torch.constant.int 1 + %6119 = torch.aten.slice.Tensor %5996, %int3_7458, %int64_7459, %int9223372036854775807_7460, %int1_7461 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %6119, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %6120 = torch.aten.neg %6119 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %6120, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %6121 = torch.prim.ListConstruct %6120, %6118 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_7462 = torch.constant.int -1 + %6122 = torch.aten.cat %6121, %int-1_7462 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6122, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %6123 = torch.aten.mul.Tensor %6122, %6116 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6123, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_7463 = torch.constant.int 1 + %6124 = torch.aten.add.Tensor %6117, %6123, %int1_7463 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6124, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_7464 = torch.constant.int 32 + %6125 = torch.aten.mul.Scalar %arg2, %int32_7464 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6125, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int21 = torch.constant.int 21 + %int1_7465 = torch.constant.int 1 + %6126 = torch.aten.add.Scalar %6125, %int21, %int1_7465 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6126, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_7466 = torch.constant.int 2 + %6127 = torch.aten.mul.Scalar %6126, %int2_7466 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6127, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_7467 = torch.constant.int 0 + %int1_7468 = torch.constant.int 1 + %6128 = torch.aten.add.Scalar %6127, %int0_7467, %int1_7468 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6128, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %6129 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %6130 = torch.aten.view %6128, %6129 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %6130, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_7469 = torch.constant.int 4 + %int32_7470 = torch.constant.int 32 + %int8_7471 = torch.constant.int 8 + %int128_7472 = torch.constant.int 128 + %6131 = torch.prim.ListConstruct %int4_7469, %296, %int32_7470, %int8_7471, %int128_7472 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + 
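// NOTE: the `+` ops below appear to write this layer's rotated K (%6124) into
// the paged KV cache (%5900 viewed as [?,32,2,8,32,128], i.e. page x
// transformer block x K/V plane x kv-head x position x head-dim;
// 2097152 = 32*2*8*32*128). Cache slots are addressed as
// (%arg2 * 32 + 21) * 2 (block index 21 in this hunk), with +0 selecting the
// K plane and +1 the V plane (%6126..%6128, %6152..%6154); activations are
// reshaped from [4,?,8,128] to heads-major [?,8,32,128] before
// torch.aten.index_put, replacing the removed positions-major [?,32,8,128]
// flattening of the old [?,32,2,32,8,128] layout.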
%6132 = torch.aten.view %6124, %6131 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6132, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_7473 = torch.constant.int 32 + %int8_7474 = torch.constant.int 8 + %int128_7475 = torch.constant.int 128 + %6133 = torch.prim.ListConstruct %504, %int32_7473, %int8_7474, %int128_7475 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6134 = torch.aten.view %6132, %6133 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %6134, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_7476 = torch.constant.int 1 + %int2_7477 = torch.constant.int 2 + %6135 = torch.aten.transpose.int %6134, %int1_7476, %int2_7477 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6135, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_7478 = torch.constant.int 5 + %6136 = torch.prims.convert_element_type %6135, %int5_7478 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6136, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_7479 = torch.constant.int 32 + %int2_7480 = torch.constant.int 2 %int8_7481 = torch.constant.int 8 - %int128_7482 = torch.constant.int 128 - %6126 = torch.prim.ListConstruct %389, %int32_7478, %int2_7479, %int32_7480, %int8_7481, %int128_7482 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6127 = torch.aten.view %6125, %6126 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6127, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7483 = torch.constant.int 32 + %int32_7482 = torch.constant.int 32 + %int128_7483 = torch.constant.int 128 + %6137 = torch.prim.ListConstruct %297, %int32_7479, %int2_7480, %int8_7481, %int32_7482, %int128_7483 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6138 = torch.aten.view %5900, %6137 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6138, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> %int8_7484 = torch.constant.int 8 - %int128_7485 = torch.constant.int 128 - %6128 = torch.prim.ListConstruct %6117, %int32_7483, %int8_7484, %int128_7485 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6129 = torch.aten.view %6127, %6128 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6129, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_7486 = torch.constant.int 4 - %int32_7487 = torch.constant.int 32 - %int8_7488 = torch.constant.int 8 - %int128_7489 = torch.constant.int 128 - %6130 = torch.prim.ListConstruct %int4_7486, %398, %int32_7487, %int8_7488, %int128_7489 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6131 = torch.aten.view %6047, %6130 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6131, [%292], affine_map<()[s0] -> (4, s0, 
32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_7490 = torch.constant.int 4 - %6132 = torch.aten.mul.int %int4_7490, %398 : !torch.int, !torch.int -> !torch.int + %int32_7485 = torch.constant.int 32 + %int128_7486 = torch.constant.int 128 + %6139 = torch.prim.ListConstruct %497, %int8_7484, %int32_7485, %int128_7486 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6140 = torch.aten.view %6138, %6139 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6140, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %6141 = torch.prim.ListConstruct %6130 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_7487 = torch.constant.bool false + %6142 = torch.aten.index_put %6140, %6141, %6136, %false_7487 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6142, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_7488 = torch.constant.int 32 + %int2_7489 = torch.constant.int 2 + %int8_7490 = torch.constant.int 8 %int32_7491 = torch.constant.int 32 - %int8_7492 = torch.constant.int 8 - %int128_7493 = torch.constant.int 128 - %6133 = torch.prim.ListConstruct %6132, %int32_7491, %int8_7492, %int128_7493 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6134 = torch.aten.view %6131, %6133 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6134, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_7494 = torch.constant.int 1 - %int1_7495 = torch.constant.int 1 - %6135 = torch.aten.add.Scalar %6105, %int1_7494, %int1_7495 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6135, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_7496 = torch.constant.int 4 - %6136 = torch.aten.mul.int %int4_7496, %398 : !torch.int, !torch.int -> !torch.int - %6137 = torch.prim.ListConstruct %6136 : (!torch.int) -> !torch.list - %6138 = torch.aten.view %6135, %6137 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %6138, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %6139 = torch.prim.ListConstruct %6138 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_7497 = torch.constant.bool false - %6140 = torch.aten.index_put %6129, %6139, %6134, %false_7497 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6140, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_7498 = torch.constant.int 32 - %int2_7499 = torch.constant.int 2 + %int128_7492 = torch.constant.int 128 + %6143 = torch.prim.ListConstruct %297, %int32_7488, %int2_7489, %int8_7490, %int32_7491, %int128_7492 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6144 = torch.aten.view %6142, %6143 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6144, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7493 = torch.constant.int 2097152 + %6145 = torch.prim.ListConstruct %297, %int2097152_7493 : 
(!torch.int, !torch.int) -> !torch.list + %6146 = torch.aten.view %6144, %6145 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6146, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_7494 = torch.constant.int 32 + %int2_7495 = torch.constant.int 2 + %int8_7496 = torch.constant.int 8 + %int32_7497 = torch.constant.int 32 + %int128_7498 = torch.constant.int 128 + %6147 = torch.prim.ListConstruct %297, %int32_7494, %int2_7495, %int8_7496, %int32_7497, %int128_7498 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6148 = torch.aten.view %6146, %6147 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6148, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_7499 = torch.constant.int 8 %int32_7500 = torch.constant.int 32 - %int8_7501 = torch.constant.int 8 - %int128_7502 = torch.constant.int 128 - %6141 = torch.prim.ListConstruct %389, %int32_7498, %int2_7499, %int32_7500, %int8_7501, %int128_7502 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6142 = torch.aten.view %6140, %6141 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6142, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_7503 = torch.constant.int 2097152 - %6143 = torch.prim.ListConstruct %389, %int2097152_7503 : (!torch.int, !torch.int) -> !torch.list - %6144 = torch.aten.view %6142, %6143 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6144, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_7504 = torch.constant.int -2 - %6145 = torch.aten.unsqueeze %6103, %int-2_7504 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6145, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_7505 = torch.constant.int 4 - %int8_7506 = torch.constant.int 8 - %int4_7507 = torch.constant.int 4 - %int128_7508 = torch.constant.int 128 - %6146 = torch.prim.ListConstruct %int4_7505, %6088, %int8_7506, %int4_7507, %int128_7508 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7509 = torch.constant.bool false - %6147 = torch.aten.expand %6145, %6146, %false_7509 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6147, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_7510 = torch.constant.int 0 - %6148 = torch.aten.clone %6147, %int0_7510 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6148, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7511 = torch.constant.int 4 + %int128_7501 = torch.constant.int 128 + %6149 = torch.prim.ListConstruct %497, %int8_7499, %int32_7500, %int128_7501 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6150 = torch.aten.view %6148, %6149 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6150, [%295], affine_map<()[s0] -> (s0 * 64, 8, 
32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_7502 = torch.constant.int 32 + %6151 = torch.aten.mul.Scalar %arg2, %int32_7502 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6151, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int21_7503 = torch.constant.int 21 + %int1_7504 = torch.constant.int 1 + %6152 = torch.aten.add.Scalar %6151, %int21_7503, %int1_7504 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6152, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_7505 = torch.constant.int 2 + %6153 = torch.aten.mul.Scalar %6152, %int2_7505 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6153, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_7506 = torch.constant.int 1 + %int1_7507 = torch.constant.int 1 + %6154 = torch.aten.add.Scalar %6153, %int1_7506, %int1_7507 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6154, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %6155 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %6156 = torch.aten.view %6154, %6155 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %6156, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_7508 = torch.constant.int 4 + %int32_7509 = torch.constant.int 32 + %int8_7510 = torch.constant.int 8 + %int128_7511 = torch.constant.int 128 + %6157 = torch.prim.ListConstruct %int4_7508, %296, %int32_7509, %int8_7510, %int128_7511 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6158 = torch.aten.view %5998, %6157 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6158, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int32_7512 = torch.constant.int 32 - %int128_7513 = torch.constant.int 128 - %6149 = torch.prim.ListConstruct %int4_7511, %6088, %int32_7512, %int128_7513 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6150 = torch.aten._unsafe_view %6148, %6149 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6150, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_7514 = torch.constant.int -2 - %6151 = torch.aten.unsqueeze %6047, %int-2_7514 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6151, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int8_7513 = torch.constant.int 8 + %int128_7514 = torch.constant.int 128 + %6159 = torch.prim.ListConstruct %504, %int32_7512, %int8_7513, %int128_7514 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6160 = torch.aten.view %6158, %6159 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %6160, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> %int1_7515 = torch.constant.int 1 - %6152 = torch.aten.size.int %6041, %int1_7515 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_7516 = torch.constant.int 4 - %int8_7517 = torch.constant.int 8 - %int4_7518 = torch.constant.int 4 - %int128_7519 
= torch.constant.int 128 - %6153 = torch.prim.ListConstruct %int4_7516, %6152, %int8_7517, %int4_7518, %int128_7519 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7520 = torch.constant.bool false - %6154 = torch.aten.expand %6151, %6153, %false_7520 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6154, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_7521 = torch.constant.int 0 - %6155 = torch.aten.clone %6154, %int0_7521 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6155, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7522 = torch.constant.int 4 - %int32_7523 = torch.constant.int 32 - %int128_7524 = torch.constant.int 128 - %6156 = torch.prim.ListConstruct %int4_7522, %6152, %int32_7523, %int128_7524 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6157 = torch.aten._unsafe_view %6155, %6156 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6157, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_7525 = torch.constant.int 1 - %int2_7526 = torch.constant.int 2 - %6158 = torch.aten.transpose.int %6075, %int1_7525, %int2_7526 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6158, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_7527 = torch.constant.int 1 - %int2_7528 = torch.constant.int 2 - %6159 = torch.aten.transpose.int %6150, %int1_7527, %int2_7528 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6159, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_7529 = torch.constant.int 1 - %int2_7530 = torch.constant.int 2 - %6160 = torch.aten.transpose.int %6157, %int1_7529, %int2_7530 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6160, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_7531 = torch.constant.float 0.000000e+00 - %true_7532 = torch.constant.bool true - %none_7533 = torch.constant.none - %none_7534 = torch.constant.none - %6161:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6158, %6159, %6160, %float0.000000e00_7531, %true_7532, %none_7533, %none_7534) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %6161#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_7535 = torch.constant.int 1 - %int2_7536 = torch.constant.int 2 - %6162 = torch.aten.transpose.int %6161#0, %int1_7535, %int2_7536 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6162, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_7537 = torch.constant.int 4 - %int4096_7538 = torch.constant.int 4096 - %6163 = torch.prim.ListConstruct %int4_7537, %6060, 
%int4096_7538 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6164 = torch.aten.view %6162, %6163 : !torch.vtensor<[4,?,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %6164, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
- %int-2_7539 = torch.constant.int -2
- %int-1_7540 = torch.constant.int -1
- %6165 = torch.aten.transpose.int %266, %int-2_7539, %int-1_7540 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_7541 = torch.constant.int 4
- %6166 = torch.aten.mul.int %int4_7541, %6060 : !torch.int, !torch.int -> !torch.int
- %int4096_7542 = torch.constant.int 4096
- %6167 = torch.prim.ListConstruct %6166, %int4096_7542 : (!torch.int, !torch.int) -> !torch.list<int>
- %6168 = torch.aten.view %6164, %6167 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %6168, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %6169 = torch.aten.mm %6168, %6165 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
- torch.bind_symbolic_shape %6169, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
- %int4_7543 = torch.constant.int 4
- %int4096_7544 = torch.constant.int 4096
- %6170 = torch.prim.ListConstruct %int4_7543, %6060, %int4096_7544 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6171 = torch.aten.view %6169, %6170 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
- torch.bind_symbolic_shape %6171, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int2_7516 = torch.constant.int 2
+ %6161 = torch.aten.transpose.int %6160, %int1_7515, %int2_7516 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %6161, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int5_7517 = torch.constant.int 5
+ %6162 = torch.prims.convert_element_type %6161, %int5_7517 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %6162, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %6163 = torch.prim.ListConstruct %6156 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
+ %false_7518 = torch.constant.bool false
+ %6164 = torch.aten.index_put %6150, %6163, %6162, %false_7518 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16>
+ torch.bind_symbolic_shape %6164, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16>
+ %int32_7519 = torch.constant.int 32
+ %int2_7520 = torch.constant.int 2
+ %int8_7521 = torch.constant.int 8
+ %int32_7522 = torch.constant.int 32
+ %int128_7523 = torch.constant.int 128
+ %6165 = torch.prim.ListConstruct %297, %int32_7519, %int2_7520, %int8_7521, %int32_7522, %int128_7523 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6166 = torch.aten.view %6164, %6165 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %6166, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_7524 = torch.constant.int 2097152
+ %6167 = torch.prim.ListConstruct %297, %int2097152_7524 : (!torch.int, !torch.int) -> !torch.list<int>
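// Annotation (editorial, not part of the diff): the added hunk above appears to scatter the
// transposed K/V slab into the shared paged cache via index_put, then re-view the cache
// page-wise as 32 blocks x 2 (K/V) x 8 KV heads x 32 positions x 128 dims = 2,097,152
// f16 elements per page, matching the flat [?,2097152] cache view built just below.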
+ %6168 = torch.aten.view %6166, %6167 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %6168, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %int-2_7525 = torch.constant.int -2
+ %6169 = torch.aten.unsqueeze %6124, %int-2_7525 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+ torch.bind_symbolic_shape %6169, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+ %int4_7526 = torch.constant.int 4
+ %int8_7527 = torch.constant.int 8
+ %int4_7528 = torch.constant.int 4
+ %int128_7529 = torch.constant.int 128
+ %6170 = torch.prim.ListConstruct %int4_7526, %298, %int8_7527, %int4_7528, %int128_7529 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %false_7530 = torch.constant.bool false
+ %6171 = torch.aten.expand %6169, %6170, %false_7530 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %6171, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int0_7531 = torch.constant.int 0
+ %6172 = torch.aten.clone %6171, %int0_7531 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %6172, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4_7532 = torch.constant.int 4
+ %int32_7533 = torch.constant.int 32
+ %int128_7534 = torch.constant.int 128
+ %6173 = torch.prim.ListConstruct %int4_7532, %298, %int32_7533, %int128_7534 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6174 = torch.aten._unsafe_view %6172, %6173 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %6174, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int-2_7535 = torch.constant.int -2
+ %6175 = torch.aten.unsqueeze %5998, %int-2_7535 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+ torch.bind_symbolic_shape %6175, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+ %int4_7536 = torch.constant.int 4
+ %int8_7537 = torch.constant.int 8
+ %int4_7538 = torch.constant.int 4
+ %int128_7539 = torch.constant.int 128
+ %6176 = torch.prim.ListConstruct %int4_7536, %298, %int8_7537, %int4_7538, %int128_7539 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %false_7540 = torch.constant.bool false
+ %6177 = torch.aten.expand %6175, %6176, %false_7540 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %6177, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int0_7541 = torch.constant.int 0
+ %6178 = torch.aten.clone %6177, %int0_7541 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %6178, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4_7542 = torch.constant.int 4
+ %int32_7543 = torch.constant.int 32
+ %int128_7544 = torch.constant.int 128
+ %6179 = torch.prim.ListConstruct %int4_7542, %298, %int32_7543, %int128_7544 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6180 = torch.aten._unsafe_view %6178, %6179 : !torch.vtensor<[4,?,8,4,128],f16>,
!torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6180, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int1_7545 = torch.constant.int 1 - %6172 = torch.aten.add.Tensor %6010, %6171, %int1_7545 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6172, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_7546 = torch.constant.int 6 - %6173 = torch.prims.convert_element_type %6172, %int6_7546 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6173, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_7547 = torch.constant.int 2 - %6174 = torch.aten.pow.Tensor_Scalar %6173, %int2_7547 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6174, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_7548 = torch.constant.int -1 - %6175 = torch.prim.ListConstruct %int-1_7548 : (!torch.int) -> !torch.list - %true_7549 = torch.constant.bool true - %none_7550 = torch.constant.none - %6176 = torch.aten.mean.dim %6174, %6175, %true_7549, %none_7550 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6176, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_7551 = torch.constant.float 9.9999997473787516E-6 - %int1_7552 = torch.constant.int 1 - %6177 = torch.aten.add.Scalar %6176, %float9.999990e-06_7551, %int1_7552 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6177, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %6178 = torch.aten.rsqrt %6177 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6178, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %6179 = torch.aten.mul.Tensor %6173, %6178 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6179, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_7553 = torch.constant.int 5 - %6180 = torch.prims.convert_element_type %6179, %int5_7553 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6180, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %6181 = torch.aten.mul.Tensor %267, %6180 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6181, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_7554 = torch.constant.int 5 - %6182 = torch.prims.convert_element_type %6181, %int5_7554 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6182, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_7555 = torch.constant.int -2 - %int-1_7556 = torch.constant.int -1 - %6183 = torch.aten.transpose.int %268, %int-2_7555, %int-1_7556 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_7557 = torch.constant.int 4 - %6184 = torch.aten.mul.int %int4_7557, %306 : 
!torch.int, !torch.int -> !torch.int - %int4096_7558 = torch.constant.int 4096 - %6185 = torch.prim.ListConstruct %6184, %int4096_7558 : (!torch.int, !torch.int) -> !torch.list - %6186 = torch.aten.view %6182, %6185 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6186, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6187 = torch.aten.mm %6186, %6183 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %6187, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_7559 = torch.constant.int 4 - %int14336_7560 = torch.constant.int 14336 - %6188 = torch.prim.ListConstruct %int4_7559, %306, %int14336_7560 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6189 = torch.aten.view %6187, %6188 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %6189, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %6190 = torch.aten.silu %6189 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %6190, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_7561 = torch.constant.int -2 - %int-1_7562 = torch.constant.int -1 - %6191 = torch.aten.transpose.int %269, %int-2_7561, %int-1_7562 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_7563 = torch.constant.int 4 - %6192 = torch.aten.mul.int %int4_7563, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7564 = torch.constant.int 4096 - %6193 = torch.prim.ListConstruct %6192, %int4096_7564 : (!torch.int, !torch.int) -> !torch.list - %6194 = torch.aten.view %6182, %6193 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6194, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6195 = torch.aten.mm %6194, %6191 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %6195, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_7565 = torch.constant.int 4 - %int14336_7566 = torch.constant.int 14336 - %6196 = torch.prim.ListConstruct %int4_7565, %306, %int14336_7566 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6197 = torch.aten.view %6195, %6196 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %6197, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %6198 = torch.aten.mul.Tensor %6190, %6197 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %6198, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_7567 = torch.constant.int -2 - %int-1_7568 = torch.constant.int -1 - %6199 = torch.aten.transpose.int %270, %int-2_7567, %int-1_7568 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_7569 = torch.constant.int 1 - %6200 = torch.aten.size.int %6189, %int1_7569 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_7570 = torch.constant.int 4 - %6201 = torch.aten.mul.int %int4_7570, %6200 : !torch.int, !torch.int -> !torch.int - %int14336_7571 = torch.constant.int 14336 - 
%6202 = torch.prim.ListConstruct %6201, %int14336_7571 : (!torch.int, !torch.int) -> !torch.list - %6203 = torch.aten.view %6198, %6202 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %6203, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %6204 = torch.aten.mm %6203, %6199 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6204, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_7572 = torch.constant.int 4 - %int4096_7573 = torch.constant.int 4096 - %6205 = torch.prim.ListConstruct %int4_7572, %6200, %int4096_7573 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6206 = torch.aten.view %6204, %6205 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6206, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_7574 = torch.constant.int 1 - %6207 = torch.aten.add.Tensor %6172, %6206, %int1_7574 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6207, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_7575 = torch.constant.int 6 - %6208 = torch.prims.convert_element_type %6207, %int6_7575 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6208, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_7576 = torch.constant.int 2 - %6209 = torch.aten.pow.Tensor_Scalar %6208, %int2_7576 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6209, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_7577 = torch.constant.int -1 - %6210 = torch.prim.ListConstruct %int-1_7577 : (!torch.int) -> !torch.list - %true_7578 = torch.constant.bool true - %none_7579 = torch.constant.none - %6211 = torch.aten.mean.dim %6209, %6210, %true_7578, %none_7579 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6211, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_7580 = torch.constant.float 9.9999997473787516E-6 - %int1_7581 = torch.constant.int 1 - %6212 = torch.aten.add.Scalar %6211, %float9.999990e-06_7580, %int1_7581 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6212, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %6213 = torch.aten.rsqrt %6212 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6213, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %6214 = torch.aten.mul.Tensor %6208, %6213 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6214, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_7546 = torch.constant.int 2 + %6181 = torch.aten.transpose.int %6061, %int1_7545, %int2_7546 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6181, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + 
%int1_7547 = torch.constant.int 1
+ %int2_7548 = torch.constant.int 2
+ %6182 = torch.aten.transpose.int %6174, %int1_7547, %int2_7548 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %6182, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_7549 = torch.constant.int 1
+ %int2_7550 = torch.constant.int 2
+ %6183 = torch.aten.transpose.int %6180, %int1_7549, %int2_7550 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %6183, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %float0.000000e00_7551 = torch.constant.float 0.000000e+00
+ %false_7552 = torch.constant.bool false
+ %none_7553 = torch.constant.none
+ %6184:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6181, %6182, %6183, %float0.000000e00_7551, %false_7552, %327, %none_7553) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>)
+ torch.bind_symbolic_shape %6184#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_7554 = torch.constant.int 1
+ %int2_7555 = torch.constant.int 2
+ %6185 = torch.aten.transpose.int %6184#0, %int1_7554, %int2_7555 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %6185, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int4_7556 = torch.constant.int 4
+ %int4096_7557 = torch.constant.int 4096
+ %6186 = torch.prim.ListConstruct %int4_7556, %298, %int4096_7557 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6187 = torch.aten.view %6185, %6186 : !torch.vtensor<[4,?,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %6187, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_7558 = torch.constant.int -2
+ %int-1_7559 = torch.constant.int -1
+ %6188 = torch.aten.transpose.int %195, %int-2_7558, %int-1_7559 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int5_7560 = torch.constant.int 5
+ %6189 = torch.prims.convert_element_type %6188, %int5_7560 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int4096_7561 = torch.constant.int 4096
+ %6190 = torch.prim.ListConstruct %342, %int4096_7561 : (!torch.int, !torch.int) -> !torch.list<int>
+ %6191 = torch.aten.view %6187, %6190 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %6191, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %6192 = torch.aten.mm %6191, %6189 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16>
+ torch.bind_symbolic_shape %6192, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16>
+ %int4_7562 = torch.constant.int 4
+ %int4096_7563 = torch.constant.int 4096
+ %6193 = torch.prim.ListConstruct %int4_7562, %298, %int4096_7563 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6194 = torch.aten.view %6192, %6193 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %6194, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
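// Annotation (editorial, not part of the diff): in the hunk above, the prefill attention
// switches from is_causal = true with no mask (the removed call passed %true_7532 and two
// nones) to is_causal = false with an explicit [4,1,?,?] f16 attention mask (%327) handed
// to torch.aten._scaled_dot_product_flash_attention_for_cpu.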
+ %int1_7564 = torch.constant.int 1
+ %6195 = torch.aten.add.Tensor %5961, %6194, %int1_7564 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %6195, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int6_7565 = torch.constant.int 6
+ %6196 = torch.prims.convert_element_type %6195, %int6_7565 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %6196, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int2_7566 = torch.constant.int 2
+ %6197 = torch.aten.pow.Tensor_Scalar %6196, %int2_7566 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %6197, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int-1_7567 = torch.constant.int -1
+ %6198 = torch.prim.ListConstruct %int-1_7567 : (!torch.int) -> !torch.list<int>
+ %true_7568 = torch.constant.bool true
+ %none_7569 = torch.constant.none
+ %6199 = torch.aten.mean.dim %6197, %6198, %true_7568, %none_7569 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %6199, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %float9.999990e-06_7570 = torch.constant.float 9.9999997473787516E-6
+ %int1_7571 = torch.constant.int 1
+ %6200 = torch.aten.add.Scalar %6199, %float9.999990e-06_7570, %int1_7571 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %6200, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %6201 = torch.aten.rsqrt %6200 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32>
+ torch.bind_symbolic_shape %6201, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32>
+ %6202 = torch.aten.mul.Tensor %6196, %6201 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %6202, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_7572 = torch.constant.int 5
+ %6203 = torch.prims.convert_element_type %6202, %int5_7572 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %6203, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %6204 = torch.aten.mul.Tensor %196, %6203 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32>
+ torch.bind_symbolic_shape %6204, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32>
+ %int5_7573 = torch.constant.int 5
+ %6205 = torch.prims.convert_element_type %6204, %int5_7573 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16>
+ torch.bind_symbolic_shape %6205, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16>
+ %int-2_7574 = torch.constant.int -2
+ %int-1_7575 = torch.constant.int -1
+ %6206 = torch.aten.transpose.int %197, %int-2_7574, %int-1_7575 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int5_7576 = torch.constant.int 5
+ %6207 = torch.prims.convert_element_type %6206, %int5_7576 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
+
%int4096_7577 = torch.constant.int 4096 + %6208 = torch.prim.ListConstruct %342, %int4096_7577 : (!torch.int, !torch.int) -> !torch.list + %6209 = torch.aten.view %6205, %6208 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6209, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6210 = torch.aten.mm %6209, %6207 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %6210, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_7578 = torch.constant.int 4 + %int14336_7579 = torch.constant.int 14336 + %6211 = torch.prim.ListConstruct %int4_7578, %298, %int14336_7579 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6212 = torch.aten.view %6210, %6211 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %6212, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %6213 = torch.aten.silu %6212 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %6213, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_7580 = torch.constant.int -2 + %int-1_7581 = torch.constant.int -1 + %6214 = torch.aten.transpose.int %198, %int-2_7580, %int-1_7581 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> %int5_7582 = torch.constant.int 5 - %6215 = torch.prims.convert_element_type %6214, %int5_7582 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6215, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %6216 = torch.aten.mul.Tensor %271, %6215 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6216, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_7583 = torch.constant.int 5 - %6217 = torch.prims.convert_element_type %6216, %int5_7583 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6217, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_7584 = torch.constant.int -2 - %int-1_7585 = torch.constant.int -1 - %6218 = torch.aten.transpose.int %272, %int-2_7584, %int-1_7585 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7586 = torch.constant.int 4 - %6219 = torch.aten.mul.int %int4_7586, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7587 = torch.constant.int 4096 - %6220 = torch.prim.ListConstruct %6219, %int4096_7587 : (!torch.int, !torch.int) -> !torch.list - %6221 = torch.aten.view %6217, %6220 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6221, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6222 = torch.aten.mm %6221, %6218 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6222, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_7588 = torch.constant.int 4 - %int4096_7589 = torch.constant.int 4096 - %6223 = torch.prim.ListConstruct %int4_7588, %306, %int4096_7589 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6224 = torch.aten.view 
%6222, %6223 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6224, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_7590 = torch.constant.int -2 - %int-1_7591 = torch.constant.int -1 - %6225 = torch.aten.transpose.int %273, %int-2_7590, %int-1_7591 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_7592 = torch.constant.int 4 - %6226 = torch.aten.mul.int %int4_7592, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7593 = torch.constant.int 4096 - %6227 = torch.prim.ListConstruct %6226, %int4096_7593 : (!torch.int, !torch.int) -> !torch.list - %6228 = torch.aten.view %6217, %6227 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6228, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6229 = torch.aten.mm %6228, %6225 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %6229, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_7594 = torch.constant.int 4 - %int1024_7595 = torch.constant.int 1024 - %6230 = torch.prim.ListConstruct %int4_7594, %306, %int1024_7595 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6231 = torch.aten.view %6229, %6230 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %6231, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_7596 = torch.constant.int -2 - %int-1_7597 = torch.constant.int -1 - %6232 = torch.aten.transpose.int %274, %int-2_7596, %int-1_7597 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_7598 = torch.constant.int 4 - %6233 = torch.aten.mul.int %int4_7598, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7599 = torch.constant.int 4096 - %6234 = torch.prim.ListConstruct %6233, %int4096_7599 : (!torch.int, !torch.int) -> !torch.list - %6235 = torch.aten.view %6217, %6234 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6235, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6236 = torch.aten.mm %6235, %6232 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %6236, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_7600 = torch.constant.int 4 - %int1024_7601 = torch.constant.int 1024 - %6237 = torch.prim.ListConstruct %int4_7600, %306, %int1024_7601 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6238 = torch.aten.view %6236, %6237 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %6238, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_7602 = torch.constant.int 4 - %int32_7603 = torch.constant.int 32 - %int128_7604 = torch.constant.int 128 - %6239 = torch.prim.ListConstruct %int4_7602, %306, %int32_7603, %int128_7604 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6240 = torch.aten.view %6224, %6239 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6240, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_7605 = torch.constant.int 
4 - %int8_7606 = torch.constant.int 8 - %int128_7607 = torch.constant.int 128 - %6241 = torch.prim.ListConstruct %int4_7605, %306, %int8_7606, %int128_7607 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6242 = torch.aten.view %6231, %6241 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6242, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_7608 = torch.constant.int 4 - %int8_7609 = torch.constant.int 8 - %int128_7610 = torch.constant.int 128 - %6243 = torch.prim.ListConstruct %int4_7608, %306, %int8_7609, %int128_7610 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6244 = torch.aten.view %6238, %6243 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6244, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_7611 = torch.constant.int 131072 - %none_7612 = torch.constant.none - %none_7613 = torch.constant.none - %cpu_7614 = torch.constant.device "cpu" - %false_7615 = torch.constant.bool false - %6245 = torch.aten.arange %int131072_7611, %none_7612, %none_7613, %cpu_7614, %false_7615 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_7616 = torch.constant.int 0 - %int128_7617 = torch.constant.int 128 - %none_7618 = torch.constant.none - %none_7619 = torch.constant.none - %cpu_7620 = torch.constant.device "cpu" - %false_7621 = torch.constant.bool false - %6246 = torch.aten.arange.start %int0_7616, %int128_7617, %none_7618, %none_7619, %cpu_7620, %false_7621 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_7622 = torch.constant.int 2 - %6247 = torch.aten.floor_divide.Scalar %6246, %int2_7622 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_7623 = torch.constant.int 6 - %6248 = torch.prims.convert_element_type %6247, %int6_7623 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_7624 = torch.constant.int 128 - %6249 = torch.aten.div.Scalar %6248, %int128_7624 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_7625 = torch.constant.float 2.000000e+00 - %6250 = torch.aten.mul.Scalar %6249, %float2.000000e00_7625 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_7626 = torch.constant.float 5.000000e+05 - %6251 = torch.aten.pow.Scalar %float5.000000e05_7626, %6250 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %6252 = torch.aten.reciprocal %6251 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_7627 = torch.constant.float 1.000000e+00 - %6253 = torch.aten.mul.Scalar %6252, %float1.000000e00_7627 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_7628 = torch.constant.int 1 - %6254 = torch.aten.unsqueeze %6245, %int1_7628 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_7629 = torch.constant.int 0 - %6255 = torch.aten.unsqueeze %6253, %int0_7629 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %6256 = torch.aten.mul.Tensor %6254, %6255 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_7630 = torch.constant.int 1 - %6257 = torch.aten.size.int %6224, %int1_7630 : !torch.vtensor<[4,?,4096],f16>, 
!torch.int -> !torch.int - %int0_7631 = torch.constant.int 0 - %6258 = torch.aten.add.int %int0_7631, %6257 : !torch.int, !torch.int -> !torch.int - %int0_7632 = torch.constant.int 0 - %int0_7633 = torch.constant.int 0 - %int1_7634 = torch.constant.int 1 - %6259 = torch.aten.slice.Tensor %6256, %int0_7632, %int0_7633, %6258, %int1_7634 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6259, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_7635 = torch.constant.int 1 - %int0_7636 = torch.constant.int 0 - %int9223372036854775807_7637 = torch.constant.int 9223372036854775807 - %int1_7638 = torch.constant.int 1 - %6260 = torch.aten.slice.Tensor %6259, %int1_7635, %int0_7636, %int9223372036854775807_7637, %int1_7638 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6260, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_7639 = torch.constant.int 1 - %int0_7640 = torch.constant.int 0 - %int9223372036854775807_7641 = torch.constant.int 9223372036854775807 - %int1_7642 = torch.constant.int 1 - %6261 = torch.aten.slice.Tensor %6260, %int1_7639, %int0_7640, %int9223372036854775807_7641, %int1_7642 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6261, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_7643 = torch.constant.int 0 - %6262 = torch.aten.unsqueeze %6261, %int0_7643 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6262, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_7644 = torch.constant.int 1 - %int0_7645 = torch.constant.int 0 - %int9223372036854775807_7646 = torch.constant.int 9223372036854775807 - %int1_7647 = torch.constant.int 1 - %6263 = torch.aten.slice.Tensor %6262, %int1_7644, %int0_7645, %int9223372036854775807_7646, %int1_7647 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6263, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_7648 = torch.constant.int 2 - %int0_7649 = torch.constant.int 0 - %int9223372036854775807_7650 = torch.constant.int 9223372036854775807 - %int1_7651 = torch.constant.int 1 - %6264 = torch.aten.slice.Tensor %6263, %int2_7648, %int0_7649, %int9223372036854775807_7650, %int1_7651 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6264, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_7652 = torch.constant.int 4 + %6215 = torch.prims.convert_element_type %6214, %int5_7582 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_7583 = torch.constant.int 4096 + %6216 = torch.prim.ListConstruct %342, %int4096_7583 : (!torch.int, !torch.int) -> !torch.list + %6217 = torch.aten.view %6205, %6216 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6217, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6218 = torch.aten.mm %6217, %6215 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> 
!torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %6218, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_7584 = torch.constant.int 4 + %int14336_7585 = torch.constant.int 14336 + %6219 = torch.prim.ListConstruct %int4_7584, %298, %int14336_7585 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6220 = torch.aten.view %6218, %6219 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %6220, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %6221 = torch.aten.mul.Tensor %6213, %6220 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %6221, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_7586 = torch.constant.int -2 + %int-1_7587 = torch.constant.int -1 + %6222 = torch.aten.transpose.int %199, %int-2_7586, %int-1_7587 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_7588 = torch.constant.int 5 + %6223 = torch.prims.convert_element_type %6222, %int5_7588 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_7589 = torch.constant.int 14336 + %6224 = torch.prim.ListConstruct %342, %int14336_7589 : (!torch.int, !torch.int) -> !torch.list + %6225 = torch.aten.view %6221, %6224 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %6225, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %6226 = torch.aten.mm %6225, %6223 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6226, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_7590 = torch.constant.int 4 + %int4096_7591 = torch.constant.int 4096 + %6227 = torch.prim.ListConstruct %int4_7590, %298, %int4096_7591 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6228 = torch.aten.view %6226, %6227 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6228, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_7592 = torch.constant.int 1 + %6229 = torch.aten.add.Tensor %6195, %6228, %int1_7592 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6229, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_7593 = torch.constant.int 6 + %6230 = torch.prims.convert_element_type %6229, %int6_7593 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6230, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_7594 = torch.constant.int 2 + %6231 = torch.aten.pow.Tensor_Scalar %6230, %int2_7594 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6231, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_7595 = torch.constant.int -1 + %6232 = torch.prim.ListConstruct %int-1_7595 : (!torch.int) -> !torch.list + %true_7596 = torch.constant.bool true + %none_7597 = torch.constant.none + %6233 = torch.aten.mean.dim %6231, %6232, %true_7596, %none_7597 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, 
!torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6233, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_7598 = torch.constant.float 9.9999997473787516E-6 + %int1_7599 = torch.constant.int 1 + %6234 = torch.aten.add.Scalar %6233, %float9.999990e-06_7598, %int1_7599 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6234, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %6235 = torch.aten.rsqrt %6234 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6235, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %6236 = torch.aten.mul.Tensor %6230, %6235 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6236, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_7600 = torch.constant.int 5 + %6237 = torch.prims.convert_element_type %6236, %int5_7600 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6237, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %6238 = torch.aten.mul.Tensor %200, %6237 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6238, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_7601 = torch.constant.int 5 + %6239 = torch.prims.convert_element_type %6238, %int5_7601 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6239, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_7602 = torch.constant.int -2 + %int-1_7603 = torch.constant.int -1 + %6240 = torch.aten.transpose.int %201, %int-2_7602, %int-1_7603 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_7604 = torch.constant.int 5 + %6241 = torch.prims.convert_element_type %6240, %int5_7604 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_7605 = torch.constant.int 4096 + %6242 = torch.prim.ListConstruct %342, %int4096_7605 : (!torch.int, !torch.int) -> !torch.list + %6243 = torch.aten.view %6239, %6242 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6243, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6244 = torch.aten.mm %6243, %6241 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6244, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_7606 = torch.constant.int 4 + %int4096_7607 = torch.constant.int 4096 + %6245 = torch.prim.ListConstruct %int4_7606, %298, %int4096_7607 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6246 = torch.aten.view %6244, %6245 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6246, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_7608 = torch.constant.int -2 + %int-1_7609 = torch.constant.int -1 + %6247 = torch.aten.transpose.int %202, %int-2_7608, %int-1_7609 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_7610 = 
torch.constant.int 5 + %6248 = torch.prims.convert_element_type %6247, %int5_7610 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_7611 = torch.constant.int 4096 + %6249 = torch.prim.ListConstruct %342, %int4096_7611 : (!torch.int, !torch.int) -> !torch.list + %6250 = torch.aten.view %6239, %6249 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6250, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6251 = torch.aten.mm %6250, %6248 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %6251, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_7612 = torch.constant.int 4 + %int1024_7613 = torch.constant.int 1024 + %6252 = torch.prim.ListConstruct %int4_7612, %298, %int1024_7613 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6253 = torch.aten.view %6251, %6252 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %6253, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_7614 = torch.constant.int -2 + %int-1_7615 = torch.constant.int -1 + %6254 = torch.aten.transpose.int %203, %int-2_7614, %int-1_7615 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_7616 = torch.constant.int 5 + %6255 = torch.prims.convert_element_type %6254, %int5_7616 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_7617 = torch.constant.int 4096 + %6256 = torch.prim.ListConstruct %342, %int4096_7617 : (!torch.int, !torch.int) -> !torch.list + %6257 = torch.aten.view %6239, %6256 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6257, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6258 = torch.aten.mm %6257, %6255 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %6258, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_7618 = torch.constant.int 4 + %int1024_7619 = torch.constant.int 1024 + %6259 = torch.prim.ListConstruct %int4_7618, %298, %int1024_7619 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6260 = torch.aten.view %6258, %6259 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %6260, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_7620 = torch.constant.int 4 + %int32_7621 = torch.constant.int 32 + %int128_7622 = torch.constant.int 128 + %6261 = torch.prim.ListConstruct %int4_7620, %298, %int32_7621, %int128_7622 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6262 = torch.aten.view %6246, %6261 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6262, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_7623 = torch.constant.int 4 + %int8_7624 = torch.constant.int 8 + %int128_7625 = torch.constant.int 128 + %6263 = torch.prim.ListConstruct %int4_7623, %298, %int8_7624, %int128_7625 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6264 = torch.aten.view %6253, %6263 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + 
torch.bind_symbolic_shape %6264, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int4_7626 = torch.constant.int 4
+ %int8_7627 = torch.constant.int 8
+ %int128_7628 = torch.constant.int 128
+ %6265 = torch.prim.ListConstruct %int4_7626, %298, %int8_7627, %int128_7628 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6266 = torch.aten.view %6260, %6265 : !torch.vtensor<[4,?,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %6266, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int131072_7629 = torch.constant.int 131072
+ %none_7630 = torch.constant.none
+ %none_7631 = torch.constant.none
+ %cpu_7632 = torch.constant.device "cpu"
+ %false_7633 = torch.constant.bool false
+ %6267 = torch.aten.arange %int131072_7629, %none_7630, %none_7631, %cpu_7632, %false_7633 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
+ %int0_7634 = torch.constant.int 0
+ %int128_7635 = torch.constant.int 128
+ %int2_7636 = torch.constant.int 2
+ %int4_7637 = torch.constant.int 4
+ %none_7638 = torch.constant.none
+ %cpu_7639 = torch.constant.device "cpu"
+ %false_7640 = torch.constant.bool false
+ %6268 = torch.aten.arange.start_step %int0_7634, %int128_7635, %int2_7636, %int4_7637, %none_7638, %cpu_7639, %false_7640 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64>
+ %int6_7641 = torch.constant.int 6
+ %6269 = torch.prims.convert_element_type %6268, %int6_7641 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32>
+ %int128_7642 = torch.constant.int 128
+ %6270 = torch.aten.div.Scalar %6269, %int128_7642 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float5.000000e05_7643 = torch.constant.float 5.000000e+05
+ %6271 = torch.aten.pow.Scalar %float5.000000e05_7643, %6270 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %6272 = torch.aten.reciprocal %6271 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float1.000000e00_7644 = torch.constant.float 1.000000e+00
+ %6273 = torch.aten.mul.Scalar %6272, %float1.000000e00_7644 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %6274 = torch.aten.reciprocal %6273 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float6.283190e00_7645 = torch.constant.float 6.2831853071795862
+ %6275 = torch.aten.mul.Scalar %6274, %float6.283190e00_7645 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %float8.192000e03_7646 = torch.constant.float 8.192000e+03
+ %6276 = torch.aten.gt.Scalar %6275, %float8.192000e03_7646 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %int8_7647 = torch.constant.int 8
+ %6277 = torch.aten.div.Scalar %6273, %int8_7647 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %6278 = torch.aten.where.self %6276, %6277, %6273 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %6279 = torch.aten.reciprocal %6275 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8192_7648 = torch.constant.int 8192
+ %6280 = torch.aten.mul.Scalar %6279, %int8192_7648 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %int1_7649 = torch.constant.int 1
+ %int1_7650 = torch.constant.int 1
+ %6281 = torch.aten.sub.Scalar %6280, %int1_7649, %int1_7650 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
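// Annotation (editorial, not part of the diff): the added hunk above appears to compute
// Llama-3-style scaled RoPE frequencies (base 5.0e+05, 64 bins for head_dim 128):
// components whose wavelength 2*pi/freq exceeds 8192 are divided by the scaling factor 8,
// the 2048..8192 wavelength band is smoothly interpolated via (8192/wavelength - 1)/3, and
// the removed lines below drop the older unscaled table that fed
// @sharktank_rotary_embedding_4_D_32_128_f32.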
+ %int3_7651 = torch.constant.int 3
+ %6282 = torch.aten.div.Scalar %6281, %int3_7651 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %int1_7652 = torch.constant.int 1
 %int1_7653 = torch.constant.int 1
- %int1_7654 = torch.constant.int 1
- %6265 = torch.prim.ListConstruct %int4_7652, %int1_7653, %int1_7654 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6266 = torch.aten.repeat %6264, %6265 : !torch.vtensor<[1,?,128],f32>, !torch.list<int> -> !torch.vtensor<[4,?,128],f32>
- torch.bind_symbolic_shape %6266, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32>
- %int6_7655 = torch.constant.int 6
- %6267 = torch.prims.convert_element_type %6240, %int6_7655 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32>
- torch.bind_symbolic_shape %6267, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32>
- %6268 = torch_c.to_builtin_tensor %6267 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32>
- %6269 = torch_c.to_builtin_tensor %6266 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32>
- %6270 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%6268, %6269) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32>
- %6271 = torch_c.from_builtin_tensor %6270 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32>
- torch.bind_symbolic_shape %6271, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32>
- %int5_7656 = torch.constant.int 5
- %6272 = torch.prims.convert_element_type %6271, %int5_7656 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %6272, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int131072_7657 = torch.constant.int 131072
- %none_7658 = torch.constant.none
- %none_7659 = torch.constant.none
- %cpu_7660 = torch.constant.device "cpu"
- %false_7661 = torch.constant.bool false
- %6273 = torch.aten.arange %int131072_7657, %none_7658, %none_7659, %cpu_7660, %false_7661 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64>
+ %6283 = torch.aten.rsub.Scalar %6282, %int1_7652, %int1_7653 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %6284 = torch.aten.mul.Tensor %6283, %6278 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8_7654 = torch.constant.int 8
+ %6285 = torch.aten.div.Scalar %6284, %int8_7654 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %6286 = torch.aten.mul.Tensor %6282, %6278 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int1_7655 = torch.constant.int 1
+ %6287 = torch.aten.add.Tensor %6285, %6286, %int1_7655 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %float2.048000e03_7656 = torch.constant.float 2.048000e+03
+ %6288 = torch.aten.lt.Scalar %6275, %float2.048000e03_7656 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %6289 = torch.aten.bitwise_not %6288 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %float8.192000e03_7657 = torch.constant.float 8.192000e+03
+ %6290 = torch.aten.gt.Scalar %6275, %float8.192000e03_7657 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %6291 = torch.aten.bitwise_not %6290 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1>
+ %6292 = torch.aten.mul.Tensor %6289,
%6291 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6293 = torch.aten.where.self %6292, %6287, %6278 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6294 = torch.prim.ListConstruct %6293, %6293 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_7658 = torch.constant.int -1 + %6295 = torch.aten.cat %6294, %int-1_7658 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_7659 = torch.constant.int 6 + %6296 = torch.prims.convert_element_type %6295, %int6_7659 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_7660 = torch.constant.int 1 + %6297 = torch.aten.unsqueeze %6267, %int1_7660 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_7661 = torch.constant.int 6 + %6298 = torch.prims.convert_element_type %6297, %int6_7661 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> %int0_7662 = torch.constant.int 0 - %int128_7663 = torch.constant.int 128 - %none_7664 = torch.constant.none - %none_7665 = torch.constant.none - %cpu_7666 = torch.constant.device "cpu" - %false_7667 = torch.constant.bool false - %6274 = torch.aten.arange.start %int0_7662, %int128_7663, %none_7664, %none_7665, %cpu_7666, %false_7667 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_7668 = torch.constant.int 2 - %6275 = torch.aten.floor_divide.Scalar %6274, %int2_7668 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_7669 = torch.constant.int 6 - %6276 = torch.prims.convert_element_type %6275, %int6_7669 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_7670 = torch.constant.int 128 - %6277 = torch.aten.div.Scalar %6276, %int128_7670 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_7671 = torch.constant.float 2.000000e+00 - %6278 = torch.aten.mul.Scalar %6277, %float2.000000e00_7671 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_7672 = torch.constant.float 5.000000e+05 - %6279 = torch.aten.pow.Scalar %float5.000000e05_7672, %6278 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %6280 = torch.aten.reciprocal %6279 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_7673 = torch.constant.float 1.000000e+00 - %6281 = torch.aten.mul.Scalar %6280, %float1.000000e00_7673 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_7674 = torch.constant.int 1 - %6282 = torch.aten.unsqueeze %6273, %int1_7674 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_7675 = torch.constant.int 0 - %6283 = torch.aten.unsqueeze %6281, %int0_7675 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %6284 = torch.aten.mul.Tensor %6282, %6283 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %6299 = torch.aten.unsqueeze %6296, %int0_7662 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_7663 = torch.constant.int 6 + %6300 = torch.prims.convert_element_type %6299, %int6_7663 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %6301 = torch.aten.mul.Tensor %6298, %6300 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %6302 = torch.aten.cos %6301 : 
!torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_7664 = torch.constant.int 5 + %6303 = torch.prims.convert_element_type %6302, %int5_7664 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %6304 = torch.aten.sin %6301 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_7665 = torch.constant.int 5 + %6305 = torch.prims.convert_element_type %6304, %int5_7665 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_7666 = torch.constant.int 0 + %int0_7667 = torch.constant.int 0 + %int1_7668 = torch.constant.int 1 + %6306 = torch.aten.slice.Tensor %6303, %int0_7666, %int0_7667, %298, %int1_7668 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6306, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_7669 = torch.constant.int 1 + %int0_7670 = torch.constant.int 0 + %int9223372036854775807_7671 = torch.constant.int 9223372036854775807 + %int1_7672 = torch.constant.int 1 + %6307 = torch.aten.slice.Tensor %6306, %int1_7669, %int0_7670, %int9223372036854775807_7671, %int1_7672 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6307, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_7673 = torch.constant.int 0 + %int0_7674 = torch.constant.int 0 + %int1_7675 = torch.constant.int 1 + %6308 = torch.aten.slice.Tensor %6305, %int0_7673, %int0_7674, %298, %int1_7675 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6308, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int1_7676 = torch.constant.int 1 - %6285 = torch.aten.size.int %6231, %int1_7676 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int %int0_7677 = torch.constant.int 0 - %6286 = torch.aten.add.int %int0_7677, %6285 : !torch.int, !torch.int -> !torch.int - %int0_7678 = torch.constant.int 0 - %int0_7679 = torch.constant.int 0 - %int1_7680 = torch.constant.int 1 - %6287 = torch.aten.slice.Tensor %6284, %int0_7678, %int0_7679, %6286, %int1_7680 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6287, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %int9223372036854775807_7678 = torch.constant.int 9223372036854775807 + %int1_7679 = torch.constant.int 1 + %6309 = torch.aten.slice.Tensor %6308, %int1_7676, %int0_7677, %int9223372036854775807_7678, %int1_7679 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6309, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_7680 = torch.constant.int 0 + %6310 = torch.aten.unsqueeze %6307, %int0_7680 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6310, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int1_7681 = torch.constant.int 1 %int0_7682 = torch.constant.int 0 %int9223372036854775807_7683 = torch.constant.int 9223372036854775807 %int1_7684 = torch.constant.int 1 - %6288 = torch.aten.slice.Tensor %6287, %int1_7681, %int0_7682, %int9223372036854775807_7683, %int1_7684 : 
!torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6288, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_7685 = torch.constant.int 1 - %int0_7686 = torch.constant.int 0 - %int9223372036854775807_7687 = torch.constant.int 9223372036854775807 - %int1_7688 = torch.constant.int 1 - %6289 = torch.aten.slice.Tensor %6288, %int1_7685, %int0_7686, %int9223372036854775807_7687, %int1_7688 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6289, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_7689 = torch.constant.int 0 - %6290 = torch.aten.unsqueeze %6289, %int0_7689 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6290, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_7690 = torch.constant.int 1 - %int0_7691 = torch.constant.int 0 - %int9223372036854775807_7692 = torch.constant.int 9223372036854775807 + %6311 = torch.aten.slice.Tensor %6310, %int1_7681, %int0_7682, %int9223372036854775807_7683, %int1_7684 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6311, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_7685 = torch.constant.int 2 + %6312 = torch.aten.unsqueeze %6311, %int2_7685 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6312, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_7686 = torch.constant.int 3 + %int0_7687 = torch.constant.int 0 + %int9223372036854775807_7688 = torch.constant.int 9223372036854775807 + %int1_7689 = torch.constant.int 1 + %6313 = torch.aten.slice.Tensor %6312, %int3_7686, %int0_7687, %int9223372036854775807_7688, %int1_7689 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6313, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_7690 = torch.constant.int 4 + %int1_7691 = torch.constant.int 1 + %int1_7692 = torch.constant.int 1 %int1_7693 = torch.constant.int 1 - %6291 = torch.aten.slice.Tensor %6290, %int1_7690, %int0_7691, %int9223372036854775807_7692, %int1_7693 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6291, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_7694 = torch.constant.int 2 - %int0_7695 = torch.constant.int 0 - %int9223372036854775807_7696 = torch.constant.int 9223372036854775807 - %int1_7697 = torch.constant.int 1 - %6292 = torch.aten.slice.Tensor %6291, %int2_7694, %int0_7695, %int9223372036854775807_7696, %int1_7697 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6292, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_7698 = torch.constant.int 4 - %int1_7699 = torch.constant.int 1 - %int1_7700 = torch.constant.int 1 - %6293 = torch.prim.ListConstruct %int4_7698, %int1_7699, %int1_7700 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6294 = torch.aten.repeat %6292, %6293 : 
!torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %6294, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_7701 = torch.constant.int 6 - %6295 = torch.prims.convert_element_type %6242, %int6_7701 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %6295, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %6296 = torch_c.to_builtin_tensor %6295 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %6297 = torch_c.to_builtin_tensor %6294 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %6298 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%6296, %6297) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %6299 = torch_c.from_builtin_tensor %6298 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %6299, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_7702 = torch.constant.int 5 - %6300 = torch.prims.convert_element_type %6299, %int5_7702 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6300, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_7703 = torch.constant.int 64 - %6301 = torch.aten.mul.Scalar %arg2, %int64_7703 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6301, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int60 = torch.constant.int 60 - %int1_7704 = torch.constant.int 1 - %6302 = torch.aten.add.Scalar %6301, %int60, %int1_7704 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6302, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_7705 = torch.constant.int 4 - %int32_7706 = torch.constant.int 32 - %int8_7707 = torch.constant.int 8 - %int128_7708 = torch.constant.int 128 - %6303 = torch.prim.ListConstruct %int4_7705, %398, %int32_7706, %int8_7707, %int128_7708 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6304 = torch.aten.view %6300, %6303 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6304, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_7709 = torch.constant.int 4 - %6305 = torch.aten.mul.int %int4_7709, %398 : !torch.int, !torch.int -> !torch.int - %int32_7710 = torch.constant.int 32 - %int8_7711 = torch.constant.int 8 - %int128_7712 = torch.constant.int 128 - %6306 = torch.prim.ListConstruct %6305, %int32_7710, %int8_7711, %int128_7712 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6307 = torch.aten.view %6304, %6306 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6307, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_7713 = torch.constant.int 4 - %6308 = torch.aten.mul.int %int4_7713, %398 : !torch.int, !torch.int -> !torch.int - %6309 = torch.prim.ListConstruct %6308 : (!torch.int) -> !torch.list - %6310 = torch.aten.view %6302, %6309 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %6310, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_7714 = 
torch.constant.int 32 - %int2_7715 = torch.constant.int 2 - %int32_7716 = torch.constant.int 32 - %int8_7717 = torch.constant.int 8 - %int128_7718 = torch.constant.int 128 - %6311 = torch.prim.ListConstruct %389, %int32_7714, %int2_7715, %int32_7716, %int8_7717, %int128_7718 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6312 = torch.aten.view %6144, %6311 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6312, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7719 = torch.constant.int 32 - %6313 = torch.aten.mul.int %389, %int32_7719 : !torch.int, !torch.int -> !torch.int - %int2_7720 = torch.constant.int 2 - %6314 = torch.aten.mul.int %6313, %int2_7720 : !torch.int, !torch.int -> !torch.int - %int32_7721 = torch.constant.int 32 - %int8_7722 = torch.constant.int 8 - %int128_7723 = torch.constant.int 128 - %6315 = torch.prim.ListConstruct %6314, %int32_7721, %int8_7722, %int128_7723 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6316 = torch.aten.view %6312, %6315 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6316, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %6317 = torch.prim.ListConstruct %6310 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_7724 = torch.constant.bool false - %6318 = torch.aten.index_put %6316, %6317, %6307, %false_7724 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6318, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_7725 = torch.constant.int 32 - %int2_7726 = torch.constant.int 2 - %int32_7727 = torch.constant.int 32 - %int8_7728 = torch.constant.int 8 - %int128_7729 = torch.constant.int 128 - %6319 = torch.prim.ListConstruct %389, %int32_7725, %int2_7726, %int32_7727, %int8_7728, %int128_7729 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6320 = torch.aten.view %6318, %6319 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6320, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_7730 = torch.constant.int 2097152 - %6321 = torch.prim.ListConstruct %389, %int2097152_7730 : (!torch.int, !torch.int) -> !torch.list - %6322 = torch.aten.view %6320, %6321 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6322, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_7731 = torch.constant.int 32 - %int2_7732 = torch.constant.int 2 - %int32_7733 = torch.constant.int 32 - %int8_7734 = torch.constant.int 8 - %int128_7735 = torch.constant.int 128 - %6323 = torch.prim.ListConstruct %389, %int32_7731, %int2_7732, %int32_7733, %int8_7734, %int128_7735 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6324 = torch.aten.view %6322, %6323 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6324, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7736 = torch.constant.int 32 - %int8_7737 = 
torch.constant.int 8 - %int128_7738 = torch.constant.int 128 - %6325 = torch.prim.ListConstruct %6314, %int32_7736, %int8_7737, %int128_7738 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6326 = torch.aten.view %6324, %6325 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6326, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_7739 = torch.constant.int 4 - %int32_7740 = torch.constant.int 32 - %int8_7741 = torch.constant.int 8 - %int128_7742 = torch.constant.int 128 - %6327 = torch.prim.ListConstruct %int4_7739, %398, %int32_7740, %int8_7741, %int128_7742 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6328 = torch.aten.view %6244, %6327 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6328, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_7743 = torch.constant.int 4 - %6329 = torch.aten.mul.int %int4_7743, %398 : !torch.int, !torch.int -> !torch.int - %int32_7744 = torch.constant.int 32 - %int8_7745 = torch.constant.int 8 - %int128_7746 = torch.constant.int 128 - %6330 = torch.prim.ListConstruct %6329, %int32_7744, %int8_7745, %int128_7746 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6331 = torch.aten.view %6328, %6330 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6331, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_7747 = torch.constant.int 1 - %int1_7748 = torch.constant.int 1 - %6332 = torch.aten.add.Scalar %6302, %int1_7747, %int1_7748 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6332, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_7749 = torch.constant.int 4 - %6333 = torch.aten.mul.int %int4_7749, %398 : !torch.int, !torch.int -> !torch.int - %6334 = torch.prim.ListConstruct %6333 : (!torch.int) -> !torch.list - %6335 = torch.aten.view %6332, %6334 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %6335, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %6336 = torch.prim.ListConstruct %6335 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_7750 = torch.constant.bool false - %6337 = torch.aten.index_put %6326, %6336, %6331, %false_7750 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6337, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_7751 = torch.constant.int 32 - %int2_7752 = torch.constant.int 2 - %int32_7753 = torch.constant.int 32 - %int8_7754 = torch.constant.int 8 - %int128_7755 = torch.constant.int 128 - %6338 = torch.prim.ListConstruct %389, %int32_7751, %int2_7752, %int32_7753, %int8_7754, %int128_7755 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6339 = torch.aten.view %6337, %6338 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6339, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_7756 = torch.constant.int 2097152 - %6340 = torch.prim.ListConstruct %389, 
%int2097152_7756 : (!torch.int, !torch.int) -> !torch.list - %6341 = torch.aten.view %6339, %6340 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6341, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_7757 = torch.constant.int -2 - %6342 = torch.aten.unsqueeze %6300, %int-2_7757 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6342, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_7758 = torch.constant.int 4 - %int8_7759 = torch.constant.int 8 - %int4_7760 = torch.constant.int 4 - %int128_7761 = torch.constant.int 128 - %6343 = torch.prim.ListConstruct %int4_7758, %6285, %int8_7759, %int4_7760, %int128_7761 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7762 = torch.constant.bool false - %6344 = torch.aten.expand %6342, %6343, %false_7762 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6344, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %6314 = torch.prim.ListConstruct %int4_7690, %int1_7691, %int1_7692, %int1_7693 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6315 = torch.aten.repeat %6313, %6314 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6315, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_7694 = torch.constant.int 0 + %6316 = torch.aten.unsqueeze %6309, %int0_7694 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6316, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_7695 = torch.constant.int 1 + %int0_7696 = torch.constant.int 0 + %int9223372036854775807_7697 = torch.constant.int 9223372036854775807 + %int1_7698 = torch.constant.int 1 + %6317 = torch.aten.slice.Tensor %6316, %int1_7695, %int0_7696, %int9223372036854775807_7697, %int1_7698 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6317, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_7699 = torch.constant.int 2 + %6318 = torch.aten.unsqueeze %6317, %int2_7699 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6318, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_7700 = torch.constant.int 3 + %int0_7701 = torch.constant.int 0 + %int9223372036854775807_7702 = torch.constant.int 9223372036854775807 + %int1_7703 = torch.constant.int 1 + %6319 = torch.aten.slice.Tensor %6318, %int3_7700, %int0_7701, %int9223372036854775807_7702, %int1_7703 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6319, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_7704 = torch.constant.int 4 + %int1_7705 = torch.constant.int 1 + %int1_7706 = torch.constant.int 1 + %int1_7707 = torch.constant.int 1 + %6320 = torch.prim.ListConstruct %int4_7704, %int1_7705, %int1_7706, %int1_7707 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6321 = torch.aten.repeat %6319, %6320 : 
!torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6321, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %6322 = torch.aten.mul.Tensor %6262, %6315 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6322, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_7708 = torch.constant.int 3 + %int0_7709 = torch.constant.int 0 + %int64_7710 = torch.constant.int 64 + %int1_7711 = torch.constant.int 1 + %6323 = torch.aten.slice.Tensor %6262, %int3_7708, %int0_7709, %int64_7710, %int1_7711 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %6323, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_7712 = torch.constant.int 3 + %int64_7713 = torch.constant.int 64 + %int9223372036854775807_7714 = torch.constant.int 9223372036854775807 + %int1_7715 = torch.constant.int 1 + %6324 = torch.aten.slice.Tensor %6262, %int3_7712, %int64_7713, %int9223372036854775807_7714, %int1_7715 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %6324, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %6325 = torch.aten.neg %6324 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %6325, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %6326 = torch.prim.ListConstruct %6325, %6323 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_7716 = torch.constant.int -1 + %6327 = torch.aten.cat %6326, %int-1_7716 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6327, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %6328 = torch.aten.mul.Tensor %6327, %6321 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6328, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_7717 = torch.constant.int 1 + %6329 = torch.aten.add.Tensor %6322, %6328, %int1_7717 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6329, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_7718 = torch.constant.int 131072 + %none_7719 = torch.constant.none + %none_7720 = torch.constant.none + %cpu_7721 = torch.constant.device "cpu" + %false_7722 = torch.constant.bool false + %6330 = torch.aten.arange %int131072_7718, %none_7719, %none_7720, %cpu_7721, %false_7722 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_7723 = torch.constant.int 0 + %int128_7724 = torch.constant.int 128 + %int2_7725 = torch.constant.int 2 + %int4_7726 = torch.constant.int 4 + %none_7727 = torch.constant.none + %cpu_7728 = torch.constant.device "cpu" + %false_7729 = torch.constant.bool false + %6331 = torch.aten.arange.start_step %int0_7723, %int128_7724, %int2_7725, %int4_7726, %none_7727, %cpu_7728, %false_7729 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, 
!torch.bool -> !torch.vtensor<[64],si64> + %int6_7730 = torch.constant.int 6 + %6332 = torch.prims.convert_element_type %6331, %int6_7730 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_7731 = torch.constant.int 128 + %6333 = torch.aten.div.Scalar %6332, %int128_7731 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_7732 = torch.constant.float 5.000000e+05 + %6334 = torch.aten.pow.Scalar %float5.000000e05_7732, %6333 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6335 = torch.aten.reciprocal %6334 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_7733 = torch.constant.float 1.000000e+00 + %6336 = torch.aten.mul.Scalar %6335, %float1.000000e00_7733 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %6337 = torch.aten.reciprocal %6336 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_7734 = torch.constant.float 6.2831853071795862 + %6338 = torch.aten.mul.Scalar %6337, %float6.283190e00_7734 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_7735 = torch.constant.float 8.192000e+03 + %6339 = torch.aten.gt.Scalar %6338, %float8.192000e03_7735 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_7736 = torch.constant.int 8 + %6340 = torch.aten.div.Scalar %6336, %int8_7736 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %6341 = torch.aten.where.self %6339, %6340, %6336 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6342 = torch.aten.reciprocal %6338 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_7737 = torch.constant.int 8192 + %6343 = torch.aten.mul.Scalar %6342, %int8192_7737 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_7738 = torch.constant.int 1 + %int1_7739 = torch.constant.int 1 + %6344 = torch.aten.sub.Scalar %6343, %int1_7738, %int1_7739 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_7740 = torch.constant.int 3 + %6345 = torch.aten.div.Scalar %6344, %int3_7740 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_7741 = torch.constant.int 1 + %int1_7742 = torch.constant.int 1 + %6346 = torch.aten.rsub.Scalar %6345, %int1_7741, %int1_7742 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %6347 = torch.aten.mul.Tensor %6346, %6341 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_7743 = torch.constant.int 8 + %6348 = torch.aten.div.Scalar %6347, %int8_7743 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %6349 = torch.aten.mul.Tensor %6345, %6341 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_7744 = torch.constant.int 1 + %6350 = torch.aten.add.Tensor %6348, %6349, %int1_7744 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_7745 = torch.constant.float 2.048000e+03 + %6351 = torch.aten.lt.Scalar %6338, %float2.048000e03_7745 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %6352 = torch.aten.bitwise_not %6351 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_7746 = torch.constant.float 8.192000e+03 + %6353 = torch.aten.gt.Scalar %6338, %float8.192000e03_7746 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %6354 = 
torch.aten.bitwise_not %6353 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6355 = torch.aten.mul.Tensor %6352, %6354 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6356 = torch.aten.where.self %6355, %6350, %6341 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6357 = torch.prim.ListConstruct %6356, %6356 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_7747 = torch.constant.int -1 + %6358 = torch.aten.cat %6357, %int-1_7747 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_7748 = torch.constant.int 6 + %6359 = torch.prims.convert_element_type %6358, %int6_7748 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_7749 = torch.constant.int 1 + %6360 = torch.aten.unsqueeze %6330, %int1_7749 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_7750 = torch.constant.int 6 + %6361 = torch.prims.convert_element_type %6360, %int6_7750 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_7751 = torch.constant.int 0 + %6362 = torch.aten.unsqueeze %6359, %int0_7751 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_7752 = torch.constant.int 6 + %6363 = torch.prims.convert_element_type %6362, %int6_7752 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %6364 = torch.aten.mul.Tensor %6361, %6363 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %6365 = torch.aten.cos %6364 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_7753 = torch.constant.int 5 + %6366 = torch.prims.convert_element_type %6365, %int5_7753 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %6367 = torch.aten.sin %6364 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_7754 = torch.constant.int 5 + %6368 = torch.prims.convert_element_type %6367, %int5_7754 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_7755 = torch.constant.int 0 + %int0_7756 = torch.constant.int 0 + %int1_7757 = torch.constant.int 1 + %6369 = torch.aten.slice.Tensor %6366, %int0_7755, %int0_7756, %298, %int1_7757 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6369, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_7758 = torch.constant.int 1 + %int0_7759 = torch.constant.int 0 + %int9223372036854775807_7760 = torch.constant.int 9223372036854775807 + %int1_7761 = torch.constant.int 1 + %6370 = torch.aten.slice.Tensor %6369, %int1_7758, %int0_7759, %int9223372036854775807_7760, %int1_7761 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6370, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_7762 = torch.constant.int 0 %int0_7763 = torch.constant.int 0 - %6345 = torch.aten.clone %6344, %int0_7763 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6345, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7764 = torch.constant.int 4 - %int32_7765 = torch.constant.int 32 - %int128_7766 = torch.constant.int 128 - %6346 = torch.prim.ListConstruct 
%int4_7764, %6285, %int32_7765, %int128_7766 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6347 = torch.aten._unsafe_view %6345, %6346 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6347, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_7767 = torch.constant.int -2 - %6348 = torch.aten.unsqueeze %6244, %int-2_7767 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6348, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int1_7764 = torch.constant.int 1 + %6371 = torch.aten.slice.Tensor %6368, %int0_7762, %int0_7763, %298, %int1_7764 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6371, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_7765 = torch.constant.int 1 + %int0_7766 = torch.constant.int 0 + %int9223372036854775807_7767 = torch.constant.int 9223372036854775807 %int1_7768 = torch.constant.int 1 - %6349 = torch.aten.size.int %6238, %int1_7768 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_7769 = torch.constant.int 4 - %int8_7770 = torch.constant.int 8 - %int4_7771 = torch.constant.int 4 - %int128_7772 = torch.constant.int 128 - %6350 = torch.prim.ListConstruct %int4_7769, %6349, %int8_7770, %int4_7771, %int128_7772 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7773 = torch.constant.bool false - %6351 = torch.aten.expand %6348, %6350, %false_7773 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6351, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_7774 = torch.constant.int 0 - %6352 = torch.aten.clone %6351, %int0_7774 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6352, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7775 = torch.constant.int 4 - %int32_7776 = torch.constant.int 32 - %int128_7777 = torch.constant.int 128 - %6353 = torch.prim.ListConstruct %int4_7775, %6349, %int32_7776, %int128_7777 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6354 = torch.aten._unsafe_view %6352, %6353 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6354, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %6372 = torch.aten.slice.Tensor %6371, %int1_7765, %int0_7766, %int9223372036854775807_7767, %int1_7768 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6372, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_7769 = torch.constant.int 0 + %6373 = torch.aten.unsqueeze %6370, %int0_7769 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6373, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_7770 = torch.constant.int 1 + %int0_7771 = torch.constant.int 0 + %int9223372036854775807_7772 = torch.constant.int 9223372036854775807 + %int1_7773 = torch.constant.int 1 + %6374 = torch.aten.slice.Tensor %6373, 
%int1_7770, %int0_7771, %int9223372036854775807_7772, %int1_7773 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6374, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_7774 = torch.constant.int 2 + %6375 = torch.aten.unsqueeze %6374, %int2_7774 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6375, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_7775 = torch.constant.int 3 + %int0_7776 = torch.constant.int 0 + %int9223372036854775807_7777 = torch.constant.int 9223372036854775807 %int1_7778 = torch.constant.int 1 - %int2_7779 = torch.constant.int 2 - %6355 = torch.aten.transpose.int %6272, %int1_7778, %int2_7779 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6355, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %6376 = torch.aten.slice.Tensor %6375, %int3_7775, %int0_7776, %int9223372036854775807_7777, %int1_7778 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6376, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_7779 = torch.constant.int 4 %int1_7780 = torch.constant.int 1 - %int2_7781 = torch.constant.int 2 - %6356 = torch.aten.transpose.int %6347, %int1_7780, %int2_7781 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6356, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_7781 = torch.constant.int 1 %int1_7782 = torch.constant.int 1 - %int2_7783 = torch.constant.int 2 - %6357 = torch.aten.transpose.int %6354, %int1_7782, %int2_7783 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6357, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_7784 = torch.constant.float 0.000000e+00 - %true_7785 = torch.constant.bool true - %none_7786 = torch.constant.none - %none_7787 = torch.constant.none - %6358:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6355, %6356, %6357, %float0.000000e00_7784, %true_7785, %none_7786, %none_7787) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %6358#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_7788 = torch.constant.int 1 - %int2_7789 = torch.constant.int 2 - %6359 = torch.aten.transpose.int %6358#0, %int1_7788, %int2_7789 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6359, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_7790 = torch.constant.int 4 - %int4096_7791 = torch.constant.int 4096 - %6360 = torch.prim.ListConstruct %int4_7790, %6257, %int4096_7791 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6361 = torch.aten.view %6359, %6360 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> 
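// Editorial note (hedged): the "-" lines above remove this block's old attention epilogue.
// In the old path the 8 KV heads were materialized out to the 32 query heads (unsqueeze to
// [4,?,8,1,128], expand to [4,?,8,4,128], clone, _unsafe_view to [4,?,32,128]), all three
// operands were transposed to [4,32,?,128], and
// torch.aten._scaled_dot_product_flash_attention_for_cpu was invoked with dropout 0.0 and
// is_causal = true, before transposing back and flattening to [4,?,4096]. A minimal PyTorch
// sketch of that removed computation; the function and variable names are illustrative, not
// taken from the IR:
//
//   import torch
//
//   def old_attention(q, k, v):
//       # q: [4, seq, 32, 128]; k, v: [4, seq, 8, 128], all fp16 as in the IR
//       n_rep = q.shape[2] // k.shape[2]                          # 32 // 8 = 4 (GQA)
//       k = k.unsqueeze(-2).expand(*k.shape[:3], n_rep, 128).reshape(4, -1, 32, 128)
//       v = v.unsqueeze(-2).expand(*v.shape[:3], n_rep, 128).reshape(4, -1, 32, 128)
//       q, k, v = (t.transpose(1, 2) for t in (q, k, v))          # [4, 32, seq, 128]
//       out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
//       return out.transpose(1, 2).reshape(4, -1, 4096)           # [4, seq, 4096]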
- torch.bind_symbolic_shape %6361, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_7792 = torch.constant.int -2 - %int-1_7793 = torch.constant.int -1 - %6362 = torch.aten.transpose.int %275, %int-2_7792, %int-1_7793 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7794 = torch.constant.int 4 - %6363 = torch.aten.mul.int %int4_7794, %6257 : !torch.int, !torch.int -> !torch.int - %int4096_7795 = torch.constant.int 4096 - %6364 = torch.prim.ListConstruct %6363, %int4096_7795 : (!torch.int, !torch.int) -> !torch.list - %6365 = torch.aten.view %6361, %6364 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6365, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6366 = torch.aten.mm %6365, %6362 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6366, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_7796 = torch.constant.int 4 - %int4096_7797 = torch.constant.int 4096 - %6367 = torch.prim.ListConstruct %int4_7796, %6257, %int4096_7797 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6368 = torch.aten.view %6366, %6367 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6368, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_7798 = torch.constant.int 1 - %6369 = torch.aten.add.Tensor %6207, %6368, %int1_7798 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6369, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_7799 = torch.constant.int 6 - %6370 = torch.prims.convert_element_type %6369, %int6_7799 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6370, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_7800 = torch.constant.int 2 - %6371 = torch.aten.pow.Tensor_Scalar %6370, %int2_7800 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6371, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_7801 = torch.constant.int -1 - %6372 = torch.prim.ListConstruct %int-1_7801 : (!torch.int) -> !torch.list - %true_7802 = torch.constant.bool true - %none_7803 = torch.constant.none - %6373 = torch.aten.mean.dim %6371, %6372, %true_7802, %none_7803 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6373, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_7804 = torch.constant.float 9.9999997473787516E-6 - %int1_7805 = torch.constant.int 1 - %6374 = torch.aten.add.Scalar %6373, %float9.999990e-06_7804, %int1_7805 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6374, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %6375 = torch.aten.rsqrt %6374 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6375, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %6376 = torch.aten.mul.Tensor %6370, %6375 : 
!torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6376, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_7806 = torch.constant.int 5 - %6377 = torch.prims.convert_element_type %6376, %int5_7806 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6377, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %6378 = torch.aten.mul.Tensor %276, %6377 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6378, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_7807 = torch.constant.int 5 - %6379 = torch.prims.convert_element_type %6378, %int5_7807 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6379, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_7808 = torch.constant.int -2 - %int-1_7809 = torch.constant.int -1 - %6380 = torch.aten.transpose.int %277, %int-2_7808, %int-1_7809 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_7810 = torch.constant.int 4 - %6381 = torch.aten.mul.int %int4_7810, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7811 = torch.constant.int 4096 - %6382 = torch.prim.ListConstruct %6381, %int4096_7811 : (!torch.int, !torch.int) -> !torch.list - %6383 = torch.aten.view %6379, %6382 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6383, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6384 = torch.aten.mm %6383, %6380 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %6384, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %6377 = torch.prim.ListConstruct %int4_7779, %int1_7780, %int1_7781, %int1_7782 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6378 = torch.aten.repeat %6376, %6377 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6378, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_7783 = torch.constant.int 0 + %6379 = torch.aten.unsqueeze %6372, %int0_7783 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6379, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_7784 = torch.constant.int 1 + %int0_7785 = torch.constant.int 0 + %int9223372036854775807_7786 = torch.constant.int 9223372036854775807 + %int1_7787 = torch.constant.int 1 + %6380 = torch.aten.slice.Tensor %6379, %int1_7784, %int0_7785, %int9223372036854775807_7786, %int1_7787 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6380, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_7788 = torch.constant.int 2 + %6381 = torch.aten.unsqueeze %6380, %int2_7788 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6381, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_7789 = torch.constant.int 3 + %int0_7790 = 
torch.constant.int 0 + %int9223372036854775807_7791 = torch.constant.int 9223372036854775807 + %int1_7792 = torch.constant.int 1 + %6382 = torch.aten.slice.Tensor %6381, %int3_7789, %int0_7790, %int9223372036854775807_7791, %int1_7792 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6382, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_7793 = torch.constant.int 4 + %int1_7794 = torch.constant.int 1 + %int1_7795 = torch.constant.int 1 + %int1_7796 = torch.constant.int 1 + %6383 = torch.prim.ListConstruct %int4_7793, %int1_7794, %int1_7795, %int1_7796 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6384 = torch.aten.repeat %6382, %6383 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6384, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %6385 = torch.aten.mul.Tensor %6264, %6378 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6385, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_7797 = torch.constant.int 3 + %int0_7798 = torch.constant.int 0 + %int64_7799 = torch.constant.int 64 + %int1_7800 = torch.constant.int 1 + %6386 = torch.aten.slice.Tensor %6264, %int3_7797, %int0_7798, %int64_7799, %int1_7800 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %6386, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_7801 = torch.constant.int 3 + %int64_7802 = torch.constant.int 64 + %int9223372036854775807_7803 = torch.constant.int 9223372036854775807 + %int1_7804 = torch.constant.int 1 + %6387 = torch.aten.slice.Tensor %6264, %int3_7801, %int64_7802, %int9223372036854775807_7803, %int1_7804 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %6387, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %6388 = torch.aten.neg %6387 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %6388, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %6389 = torch.prim.ListConstruct %6388, %6386 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_7805 = torch.constant.int -1 + %6390 = torch.aten.cat %6389, %int-1_7805 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6390, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %6391 = torch.aten.mul.Tensor %6390, %6384 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6391, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_7806 = torch.constant.int 1 + %6392 = torch.aten.add.Tensor %6385, %6391, %int1_7806 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6392, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_7807 = torch.constant.int 32 + %6393 = torch.aten.mul.Scalar %arg2, 
%int32_7807 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6393, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int22 = torch.constant.int 22 + %int1_7808 = torch.constant.int 1 + %6394 = torch.aten.add.Scalar %6393, %int22, %int1_7808 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6394, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_7809 = torch.constant.int 2 + %6395 = torch.aten.mul.Scalar %6394, %int2_7809 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6395, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_7810 = torch.constant.int 0 + %int1_7811 = torch.constant.int 1 + %6396 = torch.aten.add.Scalar %6395, %int0_7810, %int1_7811 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6396, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %6397 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %6398 = torch.aten.view %6396, %6397 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %6398, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> %int4_7812 = torch.constant.int 4 - %int14336_7813 = torch.constant.int 14336 - %6385 = torch.prim.ListConstruct %int4_7812, %306, %int14336_7813 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6386 = torch.aten.view %6384, %6385 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %6386, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %6387 = torch.aten.silu %6386 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %6387, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_7814 = torch.constant.int -2 - %int-1_7815 = torch.constant.int -1 - %6388 = torch.aten.transpose.int %278, %int-2_7814, %int-1_7815 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_7816 = torch.constant.int 4 - %6389 = torch.aten.mul.int %int4_7816, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7817 = torch.constant.int 4096 - %6390 = torch.prim.ListConstruct %6389, %int4096_7817 : (!torch.int, !torch.int) -> !torch.list - %6391 = torch.aten.view %6379, %6390 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6391, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6392 = torch.aten.mm %6391, %6388 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %6392, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_7818 = torch.constant.int 4 - %int14336_7819 = torch.constant.int 14336 - %6393 = torch.prim.ListConstruct %int4_7818, %306, %int14336_7819 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6394 = torch.aten.view %6392, %6393 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %6394, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %6395 = torch.aten.mul.Tensor %6387, %6394 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> 
!torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %6395, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_7820 = torch.constant.int -2 - %int-1_7821 = torch.constant.int -1 - %6396 = torch.aten.transpose.int %279, %int-2_7820, %int-1_7821 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_7822 = torch.constant.int 1 - %6397 = torch.aten.size.int %6386, %int1_7822 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_7823 = torch.constant.int 4 - %6398 = torch.aten.mul.int %int4_7823, %6397 : !torch.int, !torch.int -> !torch.int - %int14336_7824 = torch.constant.int 14336 - %6399 = torch.prim.ListConstruct %6398, %int14336_7824 : (!torch.int, !torch.int) -> !torch.list - %6400 = torch.aten.view %6395, %6399 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %6400, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %6401 = torch.aten.mm %6400, %6396 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6401, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_7825 = torch.constant.int 4 - %int4096_7826 = torch.constant.int 4096 - %6402 = torch.prim.ListConstruct %int4_7825, %6397, %int4096_7826 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6403 = torch.aten.view %6401, %6402 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6403, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_7827 = torch.constant.int 1 - %6404 = torch.aten.add.Tensor %6369, %6403, %int1_7827 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6404, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_7828 = torch.constant.int 6 - %6405 = torch.prims.convert_element_type %6404, %int6_7828 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6405, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_7829 = torch.constant.int 2 - %6406 = torch.aten.pow.Tensor_Scalar %6405, %int2_7829 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6406, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_7830 = torch.constant.int -1 - %6407 = torch.prim.ListConstruct %int-1_7830 : (!torch.int) -> !torch.list - %true_7831 = torch.constant.bool true - %none_7832 = torch.constant.none - %6408 = torch.aten.mean.dim %6406, %6407, %true_7831, %none_7832 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6408, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_7833 = torch.constant.float 9.9999997473787516E-6 - %int1_7834 = torch.constant.int 1 - %6409 = torch.aten.add.Scalar %6408, %float9.999990e-06_7833, %int1_7834 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6409, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %6410 = torch.aten.rsqrt %6409 : !torch.vtensor<[4,?,1],f32> -> 
!torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6410, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %6411 = torch.aten.mul.Tensor %6405, %6410 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6411, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_7835 = torch.constant.int 5 - %6412 = torch.prims.convert_element_type %6411, %int5_7835 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6412, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %6413 = torch.aten.mul.Tensor %280, %6412 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6413, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_7836 = torch.constant.int 5 - %6414 = torch.prims.convert_element_type %6413, %int5_7836 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6414, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_7837 = torch.constant.int -2 - %int-1_7838 = torch.constant.int -1 - %6415 = torch.aten.transpose.int %281, %int-2_7837, %int-1_7838 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7839 = torch.constant.int 4 - %6416 = torch.aten.mul.int %int4_7839, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7840 = torch.constant.int 4096 - %6417 = torch.prim.ListConstruct %6416, %int4096_7840 : (!torch.int, !torch.int) -> !torch.list - %6418 = torch.aten.view %6414, %6417 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6418, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6419 = torch.aten.mm %6418, %6415 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6419, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_7841 = torch.constant.int 4 - %int4096_7842 = torch.constant.int 4096 - %6420 = torch.prim.ListConstruct %int4_7841, %306, %int4096_7842 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6421 = torch.aten.view %6419, %6420 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6421, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_7843 = torch.constant.int -2 - %int-1_7844 = torch.constant.int -1 - %6422 = torch.aten.transpose.int %282, %int-2_7843, %int-1_7844 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_7845 = torch.constant.int 4 - %6423 = torch.aten.mul.int %int4_7845, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7846 = torch.constant.int 4096 - %6424 = torch.prim.ListConstruct %6423, %int4096_7846 : (!torch.int, !torch.int) -> !torch.list - %6425 = torch.aten.view %6414, %6424 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6425, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6426 = torch.aten.mm %6425, %6422 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %6426, [%292], 
affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_7847 = torch.constant.int 4 - %int1024_7848 = torch.constant.int 1024 - %6427 = torch.prim.ListConstruct %int4_7847, %306, %int1024_7848 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6428 = torch.aten.view %6426, %6427 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %6428, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int-2_7849 = torch.constant.int -2 - %int-1_7850 = torch.constant.int -1 - %6429 = torch.aten.transpose.int %283, %int-2_7849, %int-1_7850 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int32_7813 = torch.constant.int 32 + %int8_7814 = torch.constant.int 8 + %int128_7815 = torch.constant.int 128 + %6399 = torch.prim.ListConstruct %int4_7812, %296, %int32_7813, %int8_7814, %int128_7815 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6400 = torch.aten.view %6392, %6399 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6400, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_7816 = torch.constant.int 32 + %int8_7817 = torch.constant.int 8 + %int128_7818 = torch.constant.int 128 + %6401 = torch.prim.ListConstruct %504, %int32_7816, %int8_7817, %int128_7818 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6402 = torch.aten.view %6400, %6401 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %6402, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_7819 = torch.constant.int 1 + %int2_7820 = torch.constant.int 2 + %6403 = torch.aten.transpose.int %6402, %int1_7819, %int2_7820 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6403, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_7821 = torch.constant.int 5 + %6404 = torch.prims.convert_element_type %6403, %int5_7821 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6404, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_7822 = torch.constant.int 32 + %int2_7823 = torch.constant.int 2 + %int8_7824 = torch.constant.int 8 + %int32_7825 = torch.constant.int 32 + %int128_7826 = torch.constant.int 128 + %6405 = torch.prim.ListConstruct %297, %int32_7822, %int2_7823, %int8_7824, %int32_7825, %int128_7826 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6406 = torch.aten.view %6168, %6405 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6406, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_7827 = torch.constant.int 8 + %int32_7828 = torch.constant.int 32 + %int128_7829 = torch.constant.int 128 + %6407 = torch.prim.ListConstruct %497, %int8_7827, %int32_7828, %int128_7829 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6408 = torch.aten.view %6406, %6407 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6408, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : 
!torch.vtensor<[?,8,32,128],f16> + %6409 = torch.prim.ListConstruct %6398 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_7830 = torch.constant.bool false + %6410 = torch.aten.index_put %6408, %6409, %6404, %false_7830 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6410, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_7831 = torch.constant.int 32 + %int2_7832 = torch.constant.int 2 + %int8_7833 = torch.constant.int 8 + %int32_7834 = torch.constant.int 32 + %int128_7835 = torch.constant.int 128 + %6411 = torch.prim.ListConstruct %297, %int32_7831, %int2_7832, %int8_7833, %int32_7834, %int128_7835 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6412 = torch.aten.view %6410, %6411 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6412, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7836 = torch.constant.int 2097152 + %6413 = torch.prim.ListConstruct %297, %int2097152_7836 : (!torch.int, !torch.int) -> !torch.list + %6414 = torch.aten.view %6412, %6413 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6414, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_7837 = torch.constant.int 32 + %int2_7838 = torch.constant.int 2 + %int8_7839 = torch.constant.int 8 + %int32_7840 = torch.constant.int 32 + %int128_7841 = torch.constant.int 128 + %6415 = torch.prim.ListConstruct %297, %int32_7837, %int2_7838, %int8_7839, %int32_7840, %int128_7841 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6416 = torch.aten.view %6414, %6415 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6416, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_7842 = torch.constant.int 8 + %int32_7843 = torch.constant.int 32 + %int128_7844 = torch.constant.int 128 + %6417 = torch.prim.ListConstruct %497, %int8_7842, %int32_7843, %int128_7844 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6418 = torch.aten.view %6416, %6417 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6418, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_7845 = torch.constant.int 32 + %6419 = torch.aten.mul.Scalar %arg2, %int32_7845 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6419, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int22_7846 = torch.constant.int 22 + %int1_7847 = torch.constant.int 1 + %6420 = torch.aten.add.Scalar %6419, %int22_7846, %int1_7847 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6420, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_7848 = torch.constant.int 2 + %6421 = torch.aten.mul.Scalar %6420, %int2_7848 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6421, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_7849 = torch.constant.int 1 + %int1_7850 = 
torch.constant.int 1 + %6422 = torch.aten.add.Scalar %6421, %int1_7849, %int1_7850 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6422, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %6423 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %6424 = torch.aten.view %6422, %6423 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %6424, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> %int4_7851 = torch.constant.int 4 - %6430 = torch.aten.mul.int %int4_7851, %306 : !torch.int, !torch.int -> !torch.int - %int4096_7852 = torch.constant.int 4096 - %6431 = torch.prim.ListConstruct %6430, %int4096_7852 : (!torch.int, !torch.int) -> !torch.list - %6432 = torch.aten.view %6414, %6431 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6432, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6433 = torch.aten.mm %6432, %6429 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> - torch.bind_symbolic_shape %6433, [%292], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> - %int4_7853 = torch.constant.int 4 - %int1024_7854 = torch.constant.int 1024 - %6434 = torch.prim.ListConstruct %int4_7853, %306, %int1024_7854 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6435 = torch.aten.view %6433, %6434 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> - torch.bind_symbolic_shape %6435, [%292], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> - %int4_7855 = torch.constant.int 4 - %int32_7856 = torch.constant.int 32 + %int32_7852 = torch.constant.int 32 + %int8_7853 = torch.constant.int 8 + %int128_7854 = torch.constant.int 128 + %6425 = torch.prim.ListConstruct %int4_7851, %296, %int32_7852, %int8_7853, %int128_7854 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6426 = torch.aten.view %6266, %6425 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6426, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_7855 = torch.constant.int 32 + %int8_7856 = torch.constant.int 8 %int128_7857 = torch.constant.int 128 - %6436 = torch.prim.ListConstruct %int4_7855, %306, %int32_7856, %int128_7857 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6437 = torch.aten.view %6421, %6436 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6437, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_7858 = torch.constant.int 4 - %int8_7859 = torch.constant.int 8 - %int128_7860 = torch.constant.int 128 - %6438 = torch.prim.ListConstruct %int4_7858, %306, %int8_7859, %int128_7860 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6439 = torch.aten.view %6428, %6438 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6439, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int4_7861 = torch.constant.int 4 - %int8_7862 = torch.constant.int 8 - %int128_7863 = torch.constant.int 128 - %6440 = torch.prim.ListConstruct %int4_7861, %306, %int8_7862, %int128_7863 : (!torch.int, !torch.int, !torch.int, 
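The added index arithmetic (%6419 through %6424) computes flat row indices into the paged KV cache: reading the layout as 32 transformer blocks times 2 (K/V) planes per page, the literal 22 is this block's index and the trailing +1 selects the V plane (the K indices built earlier by the same pattern omit it). A sketch under that reading (names are mine):

    import torch

    def cache_row_indices(page_ids: torch.Tensor, block_idx: int, kv: int) -> torch.Tensor:
        # page_ids: the [4, ?] page table (%arg2). Assumed row layout of the
        # flattened cache is [page][block 0..31][K=0 / V=1], so:
        return ((page_ids * 32 + block_idx) * 2 + kv).flatten()

    # The V write below would correspond to cache_row_indices(page_ids, 22, 1).
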
!torch.int) -> !torch.list - %6441 = torch.aten.view %6435, %6440 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6441, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int131072_7864 = torch.constant.int 131072 - %none_7865 = torch.constant.none - %none_7866 = torch.constant.none - %cpu_7867 = torch.constant.device "cpu" - %false_7868 = torch.constant.bool false - %6442 = torch.aten.arange %int131072_7864, %none_7865, %none_7866, %cpu_7867, %false_7868 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_7869 = torch.constant.int 0 - %int128_7870 = torch.constant.int 128 - %none_7871 = torch.constant.none - %none_7872 = torch.constant.none - %cpu_7873 = torch.constant.device "cpu" - %false_7874 = torch.constant.bool false - %6443 = torch.aten.arange.start %int0_7869, %int128_7870, %none_7871, %none_7872, %cpu_7873, %false_7874 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_7875 = torch.constant.int 2 - %6444 = torch.aten.floor_divide.Scalar %6443, %int2_7875 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_7876 = torch.constant.int 6 - %6445 = torch.prims.convert_element_type %6444, %int6_7876 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> + %6427 = torch.prim.ListConstruct %504, %int32_7855, %int8_7856, %int128_7857 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6428 = torch.aten.view %6426, %6427 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %6428, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_7858 = torch.constant.int 1 + %int2_7859 = torch.constant.int 2 + %6429 = torch.aten.transpose.int %6428, %int1_7858, %int2_7859 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6429, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_7860 = torch.constant.int 5 + %6430 = torch.prims.convert_element_type %6429, %int5_7860 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6430, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %6431 = torch.prim.ListConstruct %6424 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_7861 = torch.constant.bool false + %6432 = torch.aten.index_put %6418, %6431, %6430, %false_7861 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6432, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_7862 = torch.constant.int 32 + %int2_7863 = torch.constant.int 2 + %int8_7864 = torch.constant.int 8 + %int32_7865 = torch.constant.int 32 + %int128_7866 = torch.constant.int 128 + %6433 = torch.prim.ListConstruct %297, %int32_7862, %int2_7863, %int8_7864, %int32_7865, %int128_7866 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6434 = torch.aten.view %6432, %6433 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6434, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : 
!torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7867 = torch.constant.int 2097152 + %6435 = torch.prim.ListConstruct %297, %int2097152_7867 : (!torch.int, !torch.int) -> !torch.list + %6436 = torch.aten.view %6434, %6435 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6436, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_7868 = torch.constant.int -2 + %6437 = torch.aten.unsqueeze %6392, %int-2_7868 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %6437, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_7869 = torch.constant.int 4 + %int8_7870 = torch.constant.int 8 + %int4_7871 = torch.constant.int 4 + %int128_7872 = torch.constant.int 128 + %6438 = torch.prim.ListConstruct %int4_7869, %298, %int8_7870, %int4_7871, %int128_7872 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_7873 = torch.constant.bool false + %6439 = torch.aten.expand %6437, %6438, %false_7873 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6439, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_7874 = torch.constant.int 0 + %6440 = torch.aten.clone %6439, %int0_7874 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6440, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_7875 = torch.constant.int 4 + %int32_7876 = torch.constant.int 32 %int128_7877 = torch.constant.int 128 - %6446 = torch.aten.div.Scalar %6445, %int128_7877 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_7878 = torch.constant.float 2.000000e+00 - %6447 = torch.aten.mul.Scalar %6446, %float2.000000e00_7878 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_7879 = torch.constant.float 5.000000e+05 - %6448 = torch.aten.pow.Scalar %float5.000000e05_7879, %6447 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %6449 = torch.aten.reciprocal %6448 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_7880 = torch.constant.float 1.000000e+00 - %6450 = torch.aten.mul.Scalar %6449, %float1.000000e00_7880 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_7881 = torch.constant.int 1 - %6451 = torch.aten.unsqueeze %6442, %int1_7881 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_7882 = torch.constant.int 0 - %6452 = torch.aten.unsqueeze %6450, %int0_7882 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %6453 = torch.aten.mul.Tensor %6451, %6452 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_7883 = torch.constant.int 1 - %6454 = torch.aten.size.int %6421, %int1_7883 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.int + %6441 = torch.prim.ListConstruct %int4_7875, %298, %int32_7876, %int128_7877 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6442 = torch.aten._unsafe_view %6440, %6441 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6442, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : 
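The surrounding views reinterpret the flat [?,2097152] cache argument as [pages, 32 blocks, 2 (K/V), 8 kv-heads, 32 tokens, 128] (32*2*8*32*128 = 2097152), flatten the first three dims into rows, scatter the new states with aten.index_put, and fold the result back. A PyTorch sketch of one such write, assuming the dimension reading above:

    import torch

    def write_kv_block(cache: torch.Tensor, rows: torch.Tensor,
                       states: torch.Tensor) -> torch.Tensor:
        # cache:  [num_pages, 2097152]; states: [4, seq, 8, 128], seq = blocks * 32
        paged = cache.view(-1, 32, 2, 8, 32, 128).reshape(-1, 8, 32, 128)
        blk = states.view(4, -1, 32, 8, 128).reshape(-1, 32, 8, 128)
        blk = blk.transpose(1, 2)                  # -> [rows, 8 heads, 32 tokens, 128]
        paged = paged.index_put([rows], blk)       # aten.index_put, accumulate=False
        return paged.reshape(-1, 32, 2, 8, 32, 128).reshape(cache.shape)
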
!torch.vtensor<[4,?,32,128],f16> + %int-2_7878 = torch.constant.int -2 + %6443 = torch.aten.unsqueeze %6266, %int-2_7878 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %6443, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_7879 = torch.constant.int 4 + %int8_7880 = torch.constant.int 8 + %int4_7881 = torch.constant.int 4 + %int128_7882 = torch.constant.int 128 + %6444 = torch.prim.ListConstruct %int4_7879, %298, %int8_7880, %int4_7881, %int128_7882 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_7883 = torch.constant.bool false + %6445 = torch.aten.expand %6443, %6444, %false_7883 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6445, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int0_7884 = torch.constant.int 0 - %6455 = torch.aten.add.int %int0_7884, %6454 : !torch.int, !torch.int -> !torch.int - %int0_7885 = torch.constant.int 0 - %int0_7886 = torch.constant.int 0 - %int1_7887 = torch.constant.int 1 - %6456 = torch.aten.slice.Tensor %6453, %int0_7885, %int0_7886, %6455, %int1_7887 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6456, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %6446 = torch.aten.clone %6445, %int0_7884 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6446, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_7885 = torch.constant.int 4 + %int32_7886 = torch.constant.int 32 + %int128_7887 = torch.constant.int 128 + %6447 = torch.prim.ListConstruct %int4_7885, %298, %int32_7886, %int128_7887 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6448 = torch.aten._unsafe_view %6446, %6447 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6448, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int1_7888 = torch.constant.int 1 - %int0_7889 = torch.constant.int 0 - %int9223372036854775807_7890 = torch.constant.int 9223372036854775807 - %int1_7891 = torch.constant.int 1 - %6457 = torch.aten.slice.Tensor %6456, %int1_7888, %int0_7889, %int9223372036854775807_7890, %int1_7891 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6457, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> + %int2_7889 = torch.constant.int 2 + %6449 = torch.aten.transpose.int %6329, %int1_7888, %int2_7889 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6449, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_7890 = torch.constant.int 1 + %int2_7891 = torch.constant.int 2 + %6450 = torch.aten.transpose.int %6442, %int1_7890, %int2_7891 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6450, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_7892 = torch.constant.int 1 - %int0_7893 = torch.constant.int 0 - %int9223372036854775807_7894 = 
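The unsqueeze/expand/clone/_unsafe_view sequence is the usual grouped-query-attention broadcast: the 8 cached kv-heads are replicated 4x to match the 32 query heads. Equivalent PyTorch (a sketch, not the exporter's source):

    import torch

    def repeat_kv(x: torch.Tensor, n_rep: int = 4) -> torch.Tensor:
        # [4, seq, 8, 128] -> [4, seq, 32, 128]; each kv head serves n_rep query heads.
        b, s, h, d = x.shape
        x = x.unsqueeze(-2).expand(b, s, h, n_rep, d)   # broadcast view, no copy yet
        return x.reshape(b, s, h * n_rep, d)            # clone + _unsafe_view materialize it
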
torch.constant.int 9223372036854775807 - %int1_7895 = torch.constant.int 1 - %6458 = torch.aten.slice.Tensor %6457, %int1_7892, %int0_7893, %int9223372036854775807_7894, %int1_7895 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6458, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_7896 = torch.constant.int 0 - %6459 = torch.aten.unsqueeze %6458, %int0_7896 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6459, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> + %int2_7893 = torch.constant.int 2 + %6451 = torch.aten.transpose.int %6448, %int1_7892, %int2_7893 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6451, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_7894 = torch.constant.float 0.000000e+00 + %false_7895 = torch.constant.bool false + %none_7896 = torch.constant.none + %6452:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6449, %6450, %6451, %float0.000000e00_7894, %false_7895, %327, %none_7896) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %6452#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_7897 = torch.constant.int 1 - %int0_7898 = torch.constant.int 0 - %int9223372036854775807_7899 = torch.constant.int 9223372036854775807 - %int1_7900 = torch.constant.int 1 - %6460 = torch.aten.slice.Tensor %6459, %int1_7897, %int0_7898, %int9223372036854775807_7899, %int1_7900 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6460, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_7901 = torch.constant.int 2 - %int0_7902 = torch.constant.int 0 - %int9223372036854775807_7903 = torch.constant.int 9223372036854775807 - %int1_7904 = torch.constant.int 1 - %6461 = torch.aten.slice.Tensor %6460, %int2_7901, %int0_7902, %int9223372036854775807_7903, %int1_7904 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6461, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> + %int2_7898 = torch.constant.int 2 + %6453 = torch.aten.transpose.int %6452#0, %int1_7897, %int2_7898 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6453, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_7899 = torch.constant.int 4 + %int4096_7900 = torch.constant.int 4096 + %6454 = torch.prim.ListConstruct %int4_7899, %298, %int4096_7900 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6455 = torch.aten.view %6453, %6454 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6455, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_7901 = torch.constant.int -2 + %int-1_7902 = torch.constant.int -1 + %6456 = torch.aten.transpose.int 
%204, %int-2_7901, %int-1_7902 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_7903 = torch.constant.int 5 + %6457 = torch.prims.convert_element_type %6456, %int5_7903 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_7904 = torch.constant.int 4096 + %6458 = torch.prim.ListConstruct %342, %int4096_7904 : (!torch.int, !torch.int) -> !torch.list + %6459 = torch.aten.view %6455, %6458 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6459, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6460 = torch.aten.mm %6459, %6457 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6460, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> %int4_7905 = torch.constant.int 4 - %int1_7906 = torch.constant.int 1 + %int4096_7906 = torch.constant.int 4096 + %6461 = torch.prim.ListConstruct %int4_7905, %298, %int4096_7906 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6462 = torch.aten.view %6460, %6461 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6462, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> %int1_7907 = torch.constant.int 1 - %6462 = torch.prim.ListConstruct %int4_7905, %int1_7906, %int1_7907 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6463 = torch.aten.repeat %6461, %6462 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %6463, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> + %6463 = torch.aten.add.Tensor %6229, %6462, %int1_7907 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6463, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> %int6_7908 = torch.constant.int 6 - %6464 = torch.prims.convert_element_type %6437, %int6_7908 : !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %6464, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %6465 = torch_c.to_builtin_tensor %6464 : !torch.vtensor<[4,?,32,128],f32> -> tensor<4x?x32x128xf32> - %6466 = torch_c.to_builtin_tensor %6463 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %6467 = util.call @sharktank_rotary_embedding_4_D_32_128_f32(%6465, %6466) : (tensor<4x?x32x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> - %6468 = torch_c.from_builtin_tensor %6467 : tensor<4x?x32x128xf32> -> !torch.vtensor<[4,?,32,128],f32> - torch.bind_symbolic_shape %6468, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f32> - %int5_7909 = torch.constant.int 5 - %6469 = torch.prims.convert_element_type %6468, %int5_7909 : !torch.vtensor<[4,?,32,128],f32>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6469, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int131072_7910 = torch.constant.int 131072 - %none_7911 = torch.constant.none + %6464 = torch.prims.convert_element_type %6463, %int6_7908 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6464, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : 
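After transposing Q/K/V to [4, 32, seq, 128], the new IR calls the CPU flash-attention op directly with dropout 0.0, the [4,1,?,?] additive mask (%327), and the default scale, then transposes back, re-flattens to [4,?,4096], and applies the output projection and residual add. In eager PyTorch the same step would look roughly like this, with F.scaled_dot_product_attention standing in for the flash kernel:

    import torch
    import torch.nn.functional as F

    def attention_block(q, k, v, mask, w_o, residual):
        # q, k, v: [4, seq, 32, 128]; mask: [4, 1, seq, seq]; w_o: [4096, 4096]
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))   # -> [4, 32, seq, 128]
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
        out = out.transpose(1, 2).reshape(4, -1, 4096)     # heads back together
        out = (out.reshape(-1, 4096) @ w_o.t()).reshape(4, -1, 4096)
        return residual + out                              # closes the attention residual
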
!torch.vtensor<[4,?,4096],f32> + %int2_7909 = torch.constant.int 2 + %6465 = torch.aten.pow.Tensor_Scalar %6464, %int2_7909 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6465, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_7910 = torch.constant.int -1 + %6466 = torch.prim.ListConstruct %int-1_7910 : (!torch.int) -> !torch.list + %true_7911 = torch.constant.bool true %none_7912 = torch.constant.none - %cpu_7913 = torch.constant.device "cpu" - %false_7914 = torch.constant.bool false - %6470 = torch.aten.arange %int131072_7910, %none_7911, %none_7912, %cpu_7913, %false_7914 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> - %int0_7915 = torch.constant.int 0 - %int128_7916 = torch.constant.int 128 - %none_7917 = torch.constant.none - %none_7918 = torch.constant.none - %cpu_7919 = torch.constant.device "cpu" - %false_7920 = torch.constant.bool false - %6471 = torch.aten.arange.start %int0_7915, %int128_7916, %none_7917, %none_7918, %cpu_7919, %false_7920 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> - %int2_7921 = torch.constant.int 2 - %6472 = torch.aten.floor_divide.Scalar %6471, %int2_7921 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_7922 = torch.constant.int 6 - %6473 = torch.prims.convert_element_type %6472, %int6_7922 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_7923 = torch.constant.int 128 - %6474 = torch.aten.div.Scalar %6473, %int128_7923 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00_7924 = torch.constant.float 2.000000e+00 - %6475 = torch.aten.mul.Scalar %6474, %float2.000000e00_7924 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %float5.000000e05_7925 = torch.constant.float 5.000000e+05 - %6476 = torch.aten.pow.Scalar %float5.000000e05_7925, %6475 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %6477 = torch.aten.reciprocal %6476 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %float1.000000e00_7926 = torch.constant.float 1.000000e+00 - %6478 = torch.aten.mul.Scalar %6477, %float1.000000e00_7926 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> - %int1_7927 = torch.constant.int 1 - %6479 = torch.aten.unsqueeze %6470, %int1_7927 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> - %int0_7928 = torch.constant.int 0 - %6480 = torch.aten.unsqueeze %6478, %int0_7928 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> - %6481 = torch.aten.mul.Tensor %6479, %6480 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> - %int1_7929 = torch.constant.int 1 - %6482 = torch.aten.size.int %6428, %int1_7929 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int0_7930 = torch.constant.int 0 - %6483 = torch.aten.add.int %int0_7930, %6482 : !torch.int, !torch.int -> !torch.int - %int0_7931 = torch.constant.int 0 - %int0_7932 = torch.constant.int 0 - %int1_7933 = torch.constant.int 1 - %6484 = torch.aten.slice.Tensor %6481, %int0_7931, %int0_7932, %6483, %int1_7933 : !torch.vtensor<[131072,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6484, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - 
%int1_7934 = torch.constant.int 1 - %int0_7935 = torch.constant.int 0 - %int9223372036854775807_7936 = torch.constant.int 9223372036854775807 - %int1_7937 = torch.constant.int 1 - %6485 = torch.aten.slice.Tensor %6484, %int1_7934, %int0_7935, %int9223372036854775807_7936, %int1_7937 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6485, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int1_7938 = torch.constant.int 1 - %int0_7939 = torch.constant.int 0 - %int9223372036854775807_7940 = torch.constant.int 9223372036854775807 - %int1_7941 = torch.constant.int 1 - %6486 = torch.aten.slice.Tensor %6485, %int1_7938, %int0_7939, %int9223372036854775807_7940, %int1_7941 : !torch.vtensor<[?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f32> - torch.bind_symbolic_shape %6486, [%292], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f32> - %int0_7942 = torch.constant.int 0 - %6487 = torch.aten.unsqueeze %6486, %int0_7942 : !torch.vtensor<[?,128],f32>, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6487, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int1_7943 = torch.constant.int 1 - %int0_7944 = torch.constant.int 0 - %int9223372036854775807_7945 = torch.constant.int 9223372036854775807 - %int1_7946 = torch.constant.int 1 - %6488 = torch.aten.slice.Tensor %6487, %int1_7943, %int0_7944, %int9223372036854775807_7945, %int1_7946 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6488, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int2_7947 = torch.constant.int 2 - %int0_7948 = torch.constant.int 0 - %int9223372036854775807_7949 = torch.constant.int 9223372036854775807 - %int1_7950 = torch.constant.int 1 - %6489 = torch.aten.slice.Tensor %6488, %int2_7947, %int0_7948, %int9223372036854775807_7949, %int1_7950 : !torch.vtensor<[1,?,128],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f32> - torch.bind_symbolic_shape %6489, [%292], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f32> - %int4_7951 = torch.constant.int 4 - %int1_7952 = torch.constant.int 1 - %int1_7953 = torch.constant.int 1 - %6490 = torch.prim.ListConstruct %int4_7951, %int1_7952, %int1_7953 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6491 = torch.aten.repeat %6489, %6490 : !torch.vtensor<[1,?,128],f32>, !torch.list -> !torch.vtensor<[4,?,128],f32> - torch.bind_symbolic_shape %6491, [%292], affine_map<()[s0] -> (4, s0 * 32, 128)> : !torch.vtensor<[4,?,128],f32> - %int6_7954 = torch.constant.int 6 - %6492 = torch.prims.convert_element_type %6439, %int6_7954 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape %6492, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %6493 = torch_c.to_builtin_tensor %6492 : !torch.vtensor<[4,?,8,128],f32> -> tensor<4x?x8x128xf32> - %6494 = torch_c.to_builtin_tensor %6491 : !torch.vtensor<[4,?,128],f32> -> tensor<4x?x128xf32> - %6495 = util.call @sharktank_rotary_embedding_4_D_8_128_f32(%6493, %6494) : (tensor<4x?x8x128xf32>, tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> - %6496 = torch_c.from_builtin_tensor %6495 : tensor<4x?x8x128xf32> -> !torch.vtensor<[4,?,8,128],f32> - torch.bind_symbolic_shape 
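For reference, the removed code in this region rebuilt the rotary angle table inline each time: positions arange(131072) times an inverse-frequency vector with base 5.0e5 in an interleaved layout (floor(i/2) pairs adjacent channels on one frequency), sliced to the sequence length and fed to the sharktank_rotary_embedding_* kernels. A sketch of that removed table construction, under that reading:

    import torch

    def rotary_table(max_pos: int = 131072, dim: int = 128,
                     base: float = 5.0e5) -> torch.Tensor:
        # Interleaved layout: exponent 2*floor(i/2)/dim, so even/odd channel
        # pairs share a frequency (the floor_divide.Scalar in the old IR).
        exponents = (torch.arange(dim) // 2).float() * 2.0 / dim
        inv_freq = 1.0 / base ** exponents             # pow.Scalar + reciprocal
        pos = torch.arange(max_pos).unsqueeze(1)       # [131072, 1]
        return pos * inv_freq.unsqueeze(0)             # [131072, 128] angle table
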
%6496, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f32> - %int5_7955 = torch.constant.int 5 - %6497 = torch.prims.convert_element_type %6496, %int5_7955 : !torch.vtensor<[4,?,8,128],f32>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6497, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int64_7956 = torch.constant.int 64 - %6498 = torch.aten.mul.Scalar %arg2, %int64_7956 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6498, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int62 = torch.constant.int 62 - %int1_7957 = torch.constant.int 1 - %6499 = torch.aten.add.Scalar %6498, %int62, %int1_7957 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6499, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_7958 = torch.constant.int 4 - %int32_7959 = torch.constant.int 32 - %int8_7960 = torch.constant.int 8 - %int128_7961 = torch.constant.int 128 - %6500 = torch.prim.ListConstruct %int4_7958, %398, %int32_7959, %int8_7960, %int128_7961 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6501 = torch.aten.view %6497, %6500 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6501, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_7962 = torch.constant.int 4 - %6502 = torch.aten.mul.int %int4_7962, %398 : !torch.int, !torch.int -> !torch.int - %int32_7963 = torch.constant.int 32 - %int8_7964 = torch.constant.int 8 + %6467 = torch.aten.mean.dim %6465, %6466, %true_7911, %none_7912 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6467, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_7913 = torch.constant.float 9.9999997473787516E-6 + %int1_7914 = torch.constant.int 1 + %6468 = torch.aten.add.Scalar %6467, %float9.999990e-06_7913, %int1_7914 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6468, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %6469 = torch.aten.rsqrt %6468 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6469, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %6470 = torch.aten.mul.Tensor %6464, %6469 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6470, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_7915 = torch.constant.int 5 + %6471 = torch.prims.convert_element_type %6470, %int5_7915 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6471, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %6472 = torch.aten.mul.Tensor %205, %6471 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6472, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_7916 = torch.constant.int 5 + %6473 = torch.prims.convert_element_type %6472, %int5_7916 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> 
!torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6473, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_7917 = torch.constant.int -2 + %int-1_7918 = torch.constant.int -1 + %6474 = torch.aten.transpose.int %206, %int-2_7917, %int-1_7918 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_7919 = torch.constant.int 5 + %6475 = torch.prims.convert_element_type %6474, %int5_7919 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_7920 = torch.constant.int 4096 + %6476 = torch.prim.ListConstruct %342, %int4096_7920 : (!torch.int, !torch.int) -> !torch.list + %6477 = torch.aten.view %6473, %6476 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6477, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6478 = torch.aten.mm %6477, %6475 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %6478, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_7921 = torch.constant.int 4 + %int14336_7922 = torch.constant.int 14336 + %6479 = torch.prim.ListConstruct %int4_7921, %298, %int14336_7922 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6480 = torch.aten.view %6478, %6479 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %6480, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %6481 = torch.aten.silu %6480 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %6481, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_7923 = torch.constant.int -2 + %int-1_7924 = torch.constant.int -1 + %6482 = torch.aten.transpose.int %207, %int-2_7923, %int-1_7924 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_7925 = torch.constant.int 5 + %6483 = torch.prims.convert_element_type %6482, %int5_7925 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_7926 = torch.constant.int 4096 + %6484 = torch.prim.ListConstruct %342, %int4096_7926 : (!torch.int, !torch.int) -> !torch.list + %6485 = torch.aten.view %6473, %6484 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6485, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6486 = torch.aten.mm %6485, %6483 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %6486, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_7927 = torch.constant.int 4 + %int14336_7928 = torch.constant.int 14336 + %6487 = torch.prim.ListConstruct %int4_7927, %298, %int14336_7928 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6488 = torch.aten.view %6486, %6487 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %6488, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %6489 = torch.aten.mul.Tensor %6481, %6488 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %6489, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : 
!torch.vtensor<[4,?,14336],f16> + %int-2_7929 = torch.constant.int -2 + %int-1_7930 = torch.constant.int -1 + %6490 = torch.aten.transpose.int %208, %int-2_7929, %int-1_7930 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_7931 = torch.constant.int 5 + %6491 = torch.prims.convert_element_type %6490, %int5_7931 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_7932 = torch.constant.int 14336 + %6492 = torch.prim.ListConstruct %342, %int14336_7932 : (!torch.int, !torch.int) -> !torch.list + %6493 = torch.aten.view %6489, %6492 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %6493, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %6494 = torch.aten.mm %6493, %6491 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6494, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_7933 = torch.constant.int 4 + %int4096_7934 = torch.constant.int 4096 + %6495 = torch.prim.ListConstruct %int4_7933, %298, %int4096_7934 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6496 = torch.aten.view %6494, %6495 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6496, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_7935 = torch.constant.int 1 + %6497 = torch.aten.add.Tensor %6463, %6496, %int1_7935 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6497, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_7936 = torch.constant.int 6 + %6498 = torch.prims.convert_element_type %6497, %int6_7936 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6498, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_7937 = torch.constant.int 2 + %6499 = torch.aten.pow.Tensor_Scalar %6498, %int2_7937 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6499, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_7938 = torch.constant.int -1 + %6500 = torch.prim.ListConstruct %int-1_7938 : (!torch.int) -> !torch.list + %true_7939 = torch.constant.bool true + %none_7940 = torch.constant.none + %6501 = torch.aten.mean.dim %6499, %6500, %true_7939, %none_7940 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6501, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_7941 = torch.constant.float 9.9999997473787516E-6 + %int1_7942 = torch.constant.int 1 + %6502 = torch.aten.add.Scalar %6501, %float9.999990e-06_7941, %int1_7942 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6502, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %6503 = torch.aten.rsqrt %6502 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6503, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %6504 = torch.aten.mul.Tensor %6498, %6503 : !torch.vtensor<[4,?,4096],f32>, 
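The three matmuls added here are the standard SwiGLU feed-forward: gate and up projections to 14336, silu on the gate, elementwise product, down projection back to 4096, residual add. Sketched in PyTorch (weight names are mine; the IR loads them from the blk.*.ffn_{gate,up,down} globals):

    import torch
    import torch.nn.functional as F

    def swiglu_ffn(x, w_gate, w_up, w_down, residual):
        # x: [4, seq, 4096]; w_gate, w_up: [14336, 4096]; w_down: [4096, 14336]
        flat = x.reshape(-1, 4096)
        h = F.silu(flat @ w_gate.t()) * (flat @ w_up.t())   # gated hidden state
        y = (h @ w_down.t()).reshape(4, -1, 4096)
        return residual + y
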
!torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6504, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_7943 = torch.constant.int 5 + %6505 = torch.prims.convert_element_type %6504, %int5_7943 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6505, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %6506 = torch.aten.mul.Tensor %209, %6505 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6506, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_7944 = torch.constant.int 5 + %6507 = torch.prims.convert_element_type %6506, %int5_7944 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6507, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_7945 = torch.constant.int -2 + %int-1_7946 = torch.constant.int -1 + %6508 = torch.aten.transpose.int %210, %int-2_7945, %int-1_7946 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_7947 = torch.constant.int 5 + %6509 = torch.prims.convert_element_type %6508, %int5_7947 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_7948 = torch.constant.int 4096 + %6510 = torch.prim.ListConstruct %342, %int4096_7948 : (!torch.int, !torch.int) -> !torch.list + %6511 = torch.aten.view %6507, %6510 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6511, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6512 = torch.aten.mm %6511, %6509 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6512, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_7949 = torch.constant.int 4 + %int4096_7950 = torch.constant.int 4096 + %6513 = torch.prim.ListConstruct %int4_7949, %298, %int4096_7950 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6514 = torch.aten.view %6512, %6513 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6514, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_7951 = torch.constant.int -2 + %int-1_7952 = torch.constant.int -1 + %6515 = torch.aten.transpose.int %211, %int-2_7951, %int-1_7952 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_7953 = torch.constant.int 5 + %6516 = torch.prims.convert_element_type %6515, %int5_7953 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_7954 = torch.constant.int 4096 + %6517 = torch.prim.ListConstruct %342, %int4096_7954 : (!torch.int, !torch.int) -> !torch.list + %6518 = torch.aten.view %6507, %6517 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6518, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6519 = torch.aten.mm %6518, %6516 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %6519, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_7955 = torch.constant.int 4 + 
%int1024_7956 = torch.constant.int 1024 + %6520 = torch.prim.ListConstruct %int4_7955, %298, %int1024_7956 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6521 = torch.aten.view %6519, %6520 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %6521, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_7957 = torch.constant.int -2 + %int-1_7958 = torch.constant.int -1 + %6522 = torch.aten.transpose.int %212, %int-2_7957, %int-1_7958 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_7959 = torch.constant.int 5 + %6523 = torch.prims.convert_element_type %6522, %int5_7959 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_7960 = torch.constant.int 4096 + %6524 = torch.prim.ListConstruct %342, %int4096_7960 : (!torch.int, !torch.int) -> !torch.list + %6525 = torch.aten.view %6507, %6524 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6525, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6526 = torch.aten.mm %6525, %6523 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %6526, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_7961 = torch.constant.int 4 + %int1024_7962 = torch.constant.int 1024 + %6527 = torch.prim.ListConstruct %int4_7961, %298, %int1024_7962 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6528 = torch.aten.view %6526, %6527 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %6528, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_7963 = torch.constant.int 4 + %int32_7964 = torch.constant.int 32 %int128_7965 = torch.constant.int 128 - %6503 = torch.prim.ListConstruct %6502, %int32_7963, %int8_7964, %int128_7965 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6504 = torch.aten.view %6501, %6503 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6504, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %6529 = torch.prim.ListConstruct %int4_7963, %298, %int32_7964, %int128_7965 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6530 = torch.aten.view %6514, %6529 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6530, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int4_7966 = torch.constant.int 4 - %6505 = torch.aten.mul.int %int4_7966, %398 : !torch.int, !torch.int -> !torch.int - %6506 = torch.prim.ListConstruct %6505 : (!torch.int) -> !torch.list - %6507 = torch.aten.view %6499, %6506 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %6507, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_7967 = torch.constant.int 32 - %int2_7968 = torch.constant.int 2 - %int32_7969 = torch.constant.int 32 + %int8_7967 = torch.constant.int 8 + %int128_7968 = torch.constant.int 128 + %6531 = torch.prim.ListConstruct %int4_7966, %298, %int8_7967, %int128_7968 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6532 = torch.aten.view %6521, %6531 : !torch.vtensor<[4,?,1024],f16>, 
!torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6532, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_7969 = torch.constant.int 4 %int8_7970 = torch.constant.int 8 %int128_7971 = torch.constant.int 128 - %6508 = torch.prim.ListConstruct %389, %int32_7967, %int2_7968, %int32_7969, %int8_7970, %int128_7971 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6509 = torch.aten.view %6341, %6508 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6509, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7972 = torch.constant.int 32 - %6510 = torch.aten.mul.int %389, %int32_7972 : !torch.int, !torch.int -> !torch.int - %int2_7973 = torch.constant.int 2 - %6511 = torch.aten.mul.int %6510, %int2_7973 : !torch.int, !torch.int -> !torch.int - %int32_7974 = torch.constant.int 32 - %int8_7975 = torch.constant.int 8 - %int128_7976 = torch.constant.int 128 - %6512 = torch.prim.ListConstruct %6511, %int32_7974, %int8_7975, %int128_7976 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6513 = torch.aten.view %6509, %6512 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6513, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %6514 = torch.prim.ListConstruct %6507 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_7977 = torch.constant.bool false - %6515 = torch.aten.index_put %6513, %6514, %6504, %false_7977 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6515, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_7978 = torch.constant.int 32 + %6533 = torch.prim.ListConstruct %int4_7969, %298, %int8_7970, %int128_7971 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6534 = torch.aten.view %6528, %6533 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6534, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_7972 = torch.constant.int 131072 + %none_7973 = torch.constant.none + %none_7974 = torch.constant.none + %cpu_7975 = torch.constant.device "cpu" + %false_7976 = torch.constant.bool false + %6535 = torch.aten.arange %int131072_7972, %none_7973, %none_7974, %cpu_7975, %false_7976 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_7977 = torch.constant.int 0 + %int128_7978 = torch.constant.int 128 %int2_7979 = torch.constant.int 2 - %int32_7980 = torch.constant.int 32 - %int8_7981 = torch.constant.int 8 - %int128_7982 = torch.constant.int 128 - %6516 = torch.prim.ListConstruct %389, %int32_7978, %int2_7979, %int32_7980, %int8_7981, %int128_7982 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6517 = torch.aten.view %6515, %6516 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6517, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_7983 = torch.constant.int 2097152 - %6518 = torch.prim.ListConstruct %389, %int2097152_7983 : (!torch.int, !torch.int) -> 
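The projections above are plain f16 matmuls against the transposed weights, followed by head reshapes: Q uses the [4096,4096] weight and splits into 32 heads of 128, while K and V use the [1024,4096] weights and split into 8 kv-heads. Roughly, in PyTorch:

    import torch

    def qkv_projections(x, w_q, w_k, w_v):
        # x: [4, seq, 4096]; w_q: [4096, 4096]; w_k, w_v: [1024, 4096]
        flat = x.reshape(-1, 4096)
        q = (flat @ w_q.t()).reshape(4, -1, 32, 128)   # 32 query heads
        k = (flat @ w_k.t()).reshape(4, -1, 8, 128)    # 8 shared key heads (GQA)
        v = (flat @ w_v.t()).reshape(4, -1, 8, 128)
        return q, k, v
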
!torch.list - %6519 = torch.aten.view %6517, %6518 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6519, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_7984 = torch.constant.int 32 - %int2_7985 = torch.constant.int 2 - %int32_7986 = torch.constant.int 32 - %int8_7987 = torch.constant.int 8 - %int128_7988 = torch.constant.int 128 - %6520 = torch.prim.ListConstruct %389, %int32_7984, %int2_7985, %int32_7986, %int8_7987, %int128_7988 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6521 = torch.aten.view %6519, %6520 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6521, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7989 = torch.constant.int 32 + %int4_7980 = torch.constant.int 4 + %none_7981 = torch.constant.none + %cpu_7982 = torch.constant.device "cpu" + %false_7983 = torch.constant.bool false + %6536 = torch.aten.arange.start_step %int0_7977, %int128_7978, %int2_7979, %int4_7980, %none_7981, %cpu_7982, %false_7983 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_7984 = torch.constant.int 6 + %6537 = torch.prims.convert_element_type %6536, %int6_7984 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_7985 = torch.constant.int 128 + %6538 = torch.aten.div.Scalar %6537, %int128_7985 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_7986 = torch.constant.float 5.000000e+05 + %6539 = torch.aten.pow.Scalar %float5.000000e05_7986, %6538 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6540 = torch.aten.reciprocal %6539 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_7987 = torch.constant.float 1.000000e+00 + %6541 = torch.aten.mul.Scalar %6540, %float1.000000e00_7987 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %6542 = torch.aten.reciprocal %6541 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_7988 = torch.constant.float 6.2831853071795862 + %6543 = torch.aten.mul.Scalar %6542, %float6.283190e00_7988 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_7989 = torch.constant.float 8.192000e+03 + %6544 = torch.aten.gt.Scalar %6543, %float8.192000e03_7989 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> %int8_7990 = torch.constant.int 8 - %int128_7991 = torch.constant.int 128 - %6522 = torch.prim.ListConstruct %6511, %int32_7989, %int8_7990, %int128_7991 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6523 = torch.aten.view %6521, %6522 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6523, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int4_7992 = torch.constant.int 4 - %int32_7993 = torch.constant.int 32 - %int8_7994 = torch.constant.int 8 - %int128_7995 = torch.constant.int 128 - %6524 = torch.prim.ListConstruct %int4_7992, %398, %int32_7993, %int8_7994, %int128_7995 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6525 = torch.aten.view %6441, %6524 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape 
%6525, [%292], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int4_7996 = torch.constant.int 4 - %6526 = torch.aten.mul.int %int4_7996, %398 : !torch.int, !torch.int -> !torch.int - %int32_7997 = torch.constant.int 32 - %int8_7998 = torch.constant.int 8 - %int128_7999 = torch.constant.int 128 - %6527 = torch.prim.ListConstruct %6526, %int32_7997, %int8_7998, %int128_7999 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6528 = torch.aten.view %6525, %6527 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6528, [%292], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int1_8000 = torch.constant.int 1 - %int1_8001 = torch.constant.int 1 - %6529 = torch.aten.add.Scalar %6499, %int1_8000, %int1_8001 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6529, [%292], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_8002 = torch.constant.int 4 - %6530 = torch.aten.mul.int %int4_8002, %398 : !torch.int, !torch.int -> !torch.int - %6531 = torch.prim.ListConstruct %6530 : (!torch.int) -> !torch.list - %6532 = torch.aten.view %6529, %6531 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %6532, [%292], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %6533 = torch.prim.ListConstruct %6532 : (!torch.vtensor<[?],si64>) -> !torch.list> - %false_8003 = torch.constant.bool false - %6534 = torch.aten.index_put %6523, %6533, %6528, %false_8003 : !torch.vtensor<[?,32,8,128],f16>, !torch.list>, !torch.vtensor<[?,32,8,128],f16>, !torch.bool -> !torch.vtensor<[?,32,8,128],f16> - torch.bind_symbolic_shape %6534, [%293], affine_map<()[s0] -> (s0 * 64, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> - %int32_8004 = torch.constant.int 32 - %int2_8005 = torch.constant.int 2 - %int32_8006 = torch.constant.int 32 - %int8_8007 = torch.constant.int 8 - %int128_8008 = torch.constant.int 128 - %6535 = torch.prim.ListConstruct %389, %int32_8004, %int2_8005, %int32_8006, %int8_8007, %int128_8008 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6536 = torch.aten.view %6534, %6535 : !torch.vtensor<[?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6536, [%293], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_8009 = torch.constant.int 2097152 - %6537 = torch.prim.ListConstruct %389, %int2097152_8009 : (!torch.int, !torch.int) -> !torch.list - %6538 = torch.aten.view %6536, %6537 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.overwrite.tensor.contents %6538 overwrites %arg3 : !torch.vtensor<[?,2097152],f16>, !torch.tensor<[?,2097152],f16> - torch.bind_symbolic_shape %6538, [%293], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int-2_8010 = torch.constant.int -2 - %6539 = torch.aten.unsqueeze %6497, %int-2_8010 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6539, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int4_8011 = torch.constant.int 4 - %int8_8012 = torch.constant.int 8 - %int4_8013 = torch.constant.int 4 - %int128_8014 = torch.constant.int 128 - %6540 = torch.prim.ListConstruct %int4_8011, %6482, %int8_8012, 
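// note: the removed block above updated the shared cache under the old page layout
// [?,32,2,32,8,128] (positions-major within a page, rows of [32,8,128]); the replacement
// later in this hunk switches to a heads-major layout [?,32,2,8,32,128] (rows of
// [8,32,128]), which is why the new K/V slices are transposed to [?,8,32,128] before the
// index_put.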
%int4_8013, %int128_8014 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_8015 = torch.constant.bool false - %6541 = torch.aten.expand %6539, %6540, %false_8015 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6541, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %6545 = torch.aten.div.Scalar %6541, %int8_7990 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %6546 = torch.aten.where.self %6544, %6545, %6541 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6547 = torch.aten.reciprocal %6543 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_7991 = torch.constant.int 8192 + %6548 = torch.aten.mul.Scalar %6547, %int8192_7991 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_7992 = torch.constant.int 1 + %int1_7993 = torch.constant.int 1 + %6549 = torch.aten.sub.Scalar %6548, %int1_7992, %int1_7993 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_7994 = torch.constant.int 3 + %6550 = torch.aten.div.Scalar %6549, %int3_7994 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_7995 = torch.constant.int 1 + %int1_7996 = torch.constant.int 1 + %6551 = torch.aten.rsub.Scalar %6550, %int1_7995, %int1_7996 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %6552 = torch.aten.mul.Tensor %6551, %6546 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_7997 = torch.constant.int 8 + %6553 = torch.aten.div.Scalar %6552, %int8_7997 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %6554 = torch.aten.mul.Tensor %6550, %6546 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_7998 = torch.constant.int 1 + %6555 = torch.aten.add.Tensor %6553, %6554, %int1_7998 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_7999 = torch.constant.float 2.048000e+03 + %6556 = torch.aten.lt.Scalar %6543, %float2.048000e03_7999 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %6557 = torch.aten.bitwise_not %6556 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_8000 = torch.constant.float 8.192000e+03 + %6558 = torch.aten.gt.Scalar %6543, %float8.192000e03_8000 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %6559 = torch.aten.bitwise_not %6558 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6560 = torch.aten.mul.Tensor %6557, %6559 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6561 = torch.aten.where.self %6560, %6555, %6546 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6562 = torch.prim.ListConstruct %6561, %6561 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_8001 = torch.constant.int -1 + %6563 = torch.aten.cat %6562, %int-1_8001 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_8002 = torch.constant.int 6 + %6564 = torch.prims.convert_element_type %6563, %int6_8002 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_8003 = torch.constant.int 1 + %6565 = torch.aten.unsqueeze %6535, %int1_8003 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> 
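// note: the + lines above implement Llama-3.1-style RoPE frequency scaling: inverse
// frequencies over 64 even dims from base 5.0e5, long-wavelength bins divided by a scale
// factor of 8, and a smooth blend across the wavelength band [2048, 8192]; the result is
// duplicated (cat) into a [128] vector. The lines that follow multiply it by positions
// 0..131071 to build [131072,128] cos/sin tables in f16 and slice them to the live
// sequence length %298 (bound to s0 * 32, i.e. 32 tokens per cache page).
// A hedged Python sketch of the frequency computation; the function name and keyword
// defaults are assumptions, not taken from this IR:

import math
import torch

def llama3_scaled_inv_freq(head_dim: int = 128, base: float = 5.0e5,
                           scale: float = 8.0,
                           low_wavelen: float = 2048.0,
                           high_wavelen: float = 8192.0) -> torch.Tensor:
    # inverse frequencies over the even dimensions (%6536..%6541)
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (base ** exponents)                              # [64]
    wavelen = 2.0 * math.pi / inv_freq                                # %6543
    scaled = torch.where(wavelen > high_wavelen, inv_freq / scale, inv_freq)  # %6546
    smooth = (high_wavelen / wavelen - 1.0) / 3.0                     # %6550: 0 at 8192, 1 at 2048
    blended = (1.0 - smooth) * inv_freq / scale + smooth * inv_freq   # %6555
    in_band = (wavelen >= low_wavelen) & (wavelen <= high_wavelen)    # %6560
    return torch.where(in_band, blended, scaled)                      # %6561

// end of sketch.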
+ %int6_8004 = torch.constant.int 6 + %6566 = torch.prims.convert_element_type %6565, %int6_8004 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_8005 = torch.constant.int 0 + %6567 = torch.aten.unsqueeze %6564, %int0_8005 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_8006 = torch.constant.int 6 + %6568 = torch.prims.convert_element_type %6567, %int6_8006 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %6569 = torch.aten.mul.Tensor %6566, %6568 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %6570 = torch.aten.cos %6569 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_8007 = torch.constant.int 5 + %6571 = torch.prims.convert_element_type %6570, %int5_8007 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %6572 = torch.aten.sin %6569 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_8008 = torch.constant.int 5 + %6573 = torch.prims.convert_element_type %6572, %int5_8008 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_8009 = torch.constant.int 0 + %int0_8010 = torch.constant.int 0 + %int1_8011 = torch.constant.int 1 + %6574 = torch.aten.slice.Tensor %6571, %int0_8009, %int0_8010, %298, %int1_8011 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6574, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_8012 = torch.constant.int 1 + %int0_8013 = torch.constant.int 0 + %int9223372036854775807_8014 = torch.constant.int 9223372036854775807 + %int1_8015 = torch.constant.int 1 + %6575 = torch.aten.slice.Tensor %6574, %int1_8012, %int0_8013, %int9223372036854775807_8014, %int1_8015 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6575, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> %int0_8016 = torch.constant.int 0 - %6542 = torch.aten.clone %6541, %int0_8016 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6542, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_8017 = torch.constant.int 4 - %int32_8018 = torch.constant.int 32 - %int128_8019 = torch.constant.int 128 - %6543 = torch.prim.ListConstruct %int4_8017, %6482, %int32_8018, %int128_8019 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6544 = torch.aten._unsafe_view %6542, %6543 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6544, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_8020 = torch.constant.int -2 - %6545 = torch.aten.unsqueeze %6441, %int-2_8020 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6545, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_8021 = torch.constant.int 1 - %6546 = torch.aten.size.int %6435, %int1_8021 : !torch.vtensor<[4,?,1024],f16>, !torch.int -> !torch.int - %int4_8022 = torch.constant.int 4 - %int8_8023 = torch.constant.int 8 - %int4_8024 = torch.constant.int 4 - %int128_8025 = torch.constant.int 128 - %6547 = torch.prim.ListConstruct 
%int4_8022, %6546, %int8_8023, %int4_8024, %int128_8025 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_8026 = torch.constant.bool false - %6548 = torch.aten.expand %6545, %6547, %false_8026 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6548, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_8027 = torch.constant.int 0 - %6549 = torch.aten.clone %6548, %int0_8027 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6549, [%292], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_8028 = torch.constant.int 4 - %int32_8029 = torch.constant.int 32 - %int128_8030 = torch.constant.int 128 - %6550 = torch.prim.ListConstruct %int4_8028, %6546, %int32_8029, %int128_8030 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6551 = torch.aten._unsafe_view %6549, %6550 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6551, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_8031 = torch.constant.int 1 - %int2_8032 = torch.constant.int 2 - %6552 = torch.aten.transpose.int %6469, %int1_8031, %int2_8032 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6552, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_8033 = torch.constant.int 1 - %int2_8034 = torch.constant.int 2 - %6553 = torch.aten.transpose.int %6544, %int1_8033, %int2_8034 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6553, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int0_8017 = torch.constant.int 0 + %int1_8018 = torch.constant.int 1 + %6576 = torch.aten.slice.Tensor %6573, %int0_8016, %int0_8017, %298, %int1_8018 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6576, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_8019 = torch.constant.int 1 + %int0_8020 = torch.constant.int 0 + %int9223372036854775807_8021 = torch.constant.int 9223372036854775807 + %int1_8022 = torch.constant.int 1 + %6577 = torch.aten.slice.Tensor %6576, %int1_8019, %int0_8020, %int9223372036854775807_8021, %int1_8022 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6577, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_8023 = torch.constant.int 0 + %6578 = torch.aten.unsqueeze %6575, %int0_8023 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6578, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_8024 = torch.constant.int 1 + %int0_8025 = torch.constant.int 0 + %int9223372036854775807_8026 = torch.constant.int 9223372036854775807 + %int1_8027 = torch.constant.int 1 + %6579 = torch.aten.slice.Tensor %6578, %int1_8024, %int0_8025, %int9223372036854775807_8026, %int1_8027 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape 
%6579, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_8028 = torch.constant.int 2 + %6580 = torch.aten.unsqueeze %6579, %int2_8028 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6580, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_8029 = torch.constant.int 3 + %int0_8030 = torch.constant.int 0 + %int9223372036854775807_8031 = torch.constant.int 9223372036854775807 + %int1_8032 = torch.constant.int 1 + %6581 = torch.aten.slice.Tensor %6580, %int3_8029, %int0_8030, %int9223372036854775807_8031, %int1_8032 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6581, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_8033 = torch.constant.int 4 + %int1_8034 = torch.constant.int 1 %int1_8035 = torch.constant.int 1 - %int2_8036 = torch.constant.int 2 - %6554 = torch.aten.transpose.int %6551, %int1_8035, %int2_8036 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6554, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_8037 = torch.constant.float 0.000000e+00 - %true_8038 = torch.constant.bool true - %none_8039 = torch.constant.none - %none_8040 = torch.constant.none - %6555:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6552, %6553, %6554, %float0.000000e00_8037, %true_8038, %none_8039, %none_8040) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) - torch.bind_symbolic_shape %6555#0, [%292], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_8036 = torch.constant.int 1 + %6582 = torch.prim.ListConstruct %int4_8033, %int1_8034, %int1_8035, %int1_8036 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6583 = torch.aten.repeat %6581, %6582 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6583, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_8037 = torch.constant.int 0 + %6584 = torch.aten.unsqueeze %6577, %int0_8037 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6584, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_8038 = torch.constant.int 1 + %int0_8039 = torch.constant.int 0 + %int9223372036854775807_8040 = torch.constant.int 9223372036854775807 %int1_8041 = torch.constant.int 1 + %6585 = torch.aten.slice.Tensor %6584, %int1_8038, %int0_8039, %int9223372036854775807_8040, %int1_8041 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6585, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> %int2_8042 = torch.constant.int 2 - %6556 = torch.aten.transpose.int %6555#0, %int1_8041, %int2_8042 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6556, [%292], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int4_8043 = 
torch.constant.int 4 - %int4096_8044 = torch.constant.int 4096 - %6557 = torch.prim.ListConstruct %int4_8043, %6454, %int4096_8044 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6558 = torch.aten.view %6556, %6557 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6558, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_8045 = torch.constant.int -2 - %int-1_8046 = torch.constant.int -1 - %6559 = torch.aten.transpose.int %284, %int-2_8045, %int-1_8046 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %6586 = torch.aten.unsqueeze %6585, %int2_8042 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6586, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_8043 = torch.constant.int 3 + %int0_8044 = torch.constant.int 0 + %int9223372036854775807_8045 = torch.constant.int 9223372036854775807 + %int1_8046 = torch.constant.int 1 + %6587 = torch.aten.slice.Tensor %6586, %int3_8043, %int0_8044, %int9223372036854775807_8045, %int1_8046 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6587, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> %int4_8047 = torch.constant.int 4 - %6560 = torch.aten.mul.int %int4_8047, %6454 : !torch.int, !torch.int -> !torch.int - %int4096_8048 = torch.constant.int 4096 - %6561 = torch.prim.ListConstruct %6560, %int4096_8048 : (!torch.int, !torch.int) -> !torch.list - %6562 = torch.aten.view %6558, %6561 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6562, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6563 = torch.aten.mm %6562, %6559 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6563, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_8049 = torch.constant.int 4 - %int4096_8050 = torch.constant.int 4096 - %6564 = torch.prim.ListConstruct %int4_8049, %6454, %int4096_8050 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6565 = torch.aten.view %6563, %6564 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6565, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_8051 = torch.constant.int 1 - %6566 = torch.aten.add.Tensor %6404, %6565, %int1_8051 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6566, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_8052 = torch.constant.int 6 - %6567 = torch.prims.convert_element_type %6566, %int6_8052 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6567, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_8053 = torch.constant.int 2 - %6568 = torch.aten.pow.Tensor_Scalar %6567, %int2_8053 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6568, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_8054 = torch.constant.int -1 - %6569 = 
torch.prim.ListConstruct %int-1_8054 : (!torch.int) -> !torch.list - %true_8055 = torch.constant.bool true - %none_8056 = torch.constant.none - %6570 = torch.aten.mean.dim %6568, %6569, %true_8055, %none_8056 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6570, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_8057 = torch.constant.float 9.9999997473787516E-6 + %int1_8048 = torch.constant.int 1 + %int1_8049 = torch.constant.int 1 + %int1_8050 = torch.constant.int 1 + %6588 = torch.prim.ListConstruct %int4_8047, %int1_8048, %int1_8049, %int1_8050 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6589 = torch.aten.repeat %6587, %6588 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6589, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %6590 = torch.aten.mul.Tensor %6530, %6583 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6590, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_8051 = torch.constant.int 3 + %int0_8052 = torch.constant.int 0 + %int64_8053 = torch.constant.int 64 + %int1_8054 = torch.constant.int 1 + %6591 = torch.aten.slice.Tensor %6530, %int3_8051, %int0_8052, %int64_8053, %int1_8054 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %6591, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_8055 = torch.constant.int 3 + %int64_8056 = torch.constant.int 64 + %int9223372036854775807_8057 = torch.constant.int 9223372036854775807 %int1_8058 = torch.constant.int 1 - %6571 = torch.aten.add.Scalar %6570, %float9.999990e-06_8057, %int1_8058 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6571, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %6572 = torch.aten.rsqrt %6571 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6572, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %6573 = torch.aten.mul.Tensor %6567, %6572 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6573, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_8059 = torch.constant.int 5 - %6574 = torch.prims.convert_element_type %6573, %int5_8059 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6574, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %6575 = torch.aten.mul.Tensor %285, %6574 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6575, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_8060 = torch.constant.int 5 - %6576 = torch.prims.convert_element_type %6575, %int5_8060 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6576, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_8061 = torch.constant.int -2 - %int-1_8062 = 
torch.constant.int -1 - %6577 = torch.aten.transpose.int %286, %int-2_8061, %int-1_8062 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_8063 = torch.constant.int 4 - %6578 = torch.aten.mul.int %int4_8063, %306 : !torch.int, !torch.int -> !torch.int - %int4096_8064 = torch.constant.int 4096 - %6579 = torch.prim.ListConstruct %6578, %int4096_8064 : (!torch.int, !torch.int) -> !torch.list - %6580 = torch.aten.view %6576, %6579 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6580, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6581 = torch.aten.mm %6580, %6577 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %6581, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_8065 = torch.constant.int 4 - %int14336_8066 = torch.constant.int 14336 - %6582 = torch.prim.ListConstruct %int4_8065, %306, %int14336_8066 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6583 = torch.aten.view %6581, %6582 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %6583, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %6584 = torch.aten.silu %6583 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %6584, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_8067 = torch.constant.int -2 - %int-1_8068 = torch.constant.int -1 - %6585 = torch.aten.transpose.int %287, %int-2_8067, %int-1_8068 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %6592 = torch.aten.slice.Tensor %6530, %int3_8055, %int64_8056, %int9223372036854775807_8057, %int1_8058 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %6592, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %6593 = torch.aten.neg %6592 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %6593, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %6594 = torch.prim.ListConstruct %6593, %6591 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_8059 = torch.constant.int -1 + %6595 = torch.aten.cat %6594, %int-1_8059 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6595, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %6596 = torch.aten.mul.Tensor %6595, %6589 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6596, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_8060 = torch.constant.int 1 + %6597 = torch.aten.add.Tensor %6590, %6596, %int1_8060 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6597, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_8061 = torch.constant.int 131072 + %none_8062 = torch.constant.none + %none_8063 = torch.constant.none + %cpu_8064 = torch.constant.device "cpu" + 
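// note: %6590..%6597 rotate the 32-head query tensor:
//   q_rot = q * cos + rotate_half(q) * sin,  rotate_half(q) = cat(-q[..., 64:], q[..., :64], dim=-1)
// The + lines that follow rebuild the identical position/frequency tables from scratch
// for the key tensor; the duplication looks redundant and is presumably left for CSE to fold.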
%false_8065 = torch.constant.bool false + %6598 = torch.aten.arange %int131072_8061, %none_8062, %none_8063, %cpu_8064, %false_8065 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_8066 = torch.constant.int 0 + %int128_8067 = torch.constant.int 128 + %int2_8068 = torch.constant.int 2 %int4_8069 = torch.constant.int 4 - %6586 = torch.aten.mul.int %int4_8069, %306 : !torch.int, !torch.int -> !torch.int - %int4096_8070 = torch.constant.int 4096 - %6587 = torch.prim.ListConstruct %6586, %int4096_8070 : (!torch.int, !torch.int) -> !torch.list - %6588 = torch.aten.view %6576, %6587 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6588, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6589 = torch.aten.mm %6588, %6585 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %6589, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %int4_8071 = torch.constant.int 4 - %int14336_8072 = torch.constant.int 14336 - %6590 = torch.prim.ListConstruct %int4_8071, %306, %int14336_8072 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6591 = torch.aten.view %6589, %6590 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %6591, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %6592 = torch.aten.mul.Tensor %6584, %6591 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> - torch.bind_symbolic_shape %6592, [%292], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> - %int-2_8073 = torch.constant.int -2 - %int-1_8074 = torch.constant.int -1 - %6593 = torch.aten.transpose.int %288, %int-2_8073, %int-1_8074 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int1_8075 = torch.constant.int 1 - %6594 = torch.aten.size.int %6583, %int1_8075 : !torch.vtensor<[4,?,14336],f16>, !torch.int -> !torch.int - %int4_8076 = torch.constant.int 4 - %6595 = torch.aten.mul.int %int4_8076, %6594 : !torch.int, !torch.int -> !torch.int - %int14336_8077 = torch.constant.int 14336 - %6596 = torch.prim.ListConstruct %6595, %int14336_8077 : (!torch.int, !torch.int) -> !torch.list - %6597 = torch.aten.view %6592, %6596 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> - torch.bind_symbolic_shape %6597, [%292], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> - %6598 = torch.aten.mm %6597, %6593 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6598, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %int4_8078 = torch.constant.int 4 - %int4096_8079 = torch.constant.int 4096 - %6599 = torch.prim.ListConstruct %int4_8078, %6594, %int4096_8079 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6600 = torch.aten.view %6598, %6599 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6600, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int1_8080 = torch.constant.int 1 - %6601 = torch.aten.add.Tensor %6566, %6600, %int1_8080 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> 
- torch.bind_symbolic_shape %6601, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int6_8081 = torch.constant.int 6 - %6602 = torch.prims.convert_element_type %6601, %int6_8081 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6602, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int2_8082 = torch.constant.int 2 - %6603 = torch.aten.pow.Tensor_Scalar %6602, %int2_8082 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6603, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int-1_8083 = torch.constant.int -1 - %6604 = torch.prim.ListConstruct %int-1_8083 : (!torch.int) -> !torch.list - %true_8084 = torch.constant.bool true - %none_8085 = torch.constant.none - %6605 = torch.aten.mean.dim %6603, %6604, %true_8084, %none_8085 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6605, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %float9.999990e-06_8086 = torch.constant.float 9.9999997473787516E-6 + %none_8070 = torch.constant.none + %cpu_8071 = torch.constant.device "cpu" + %false_8072 = torch.constant.bool false + %6599 = torch.aten.arange.start_step %int0_8066, %int128_8067, %int2_8068, %int4_8069, %none_8070, %cpu_8071, %false_8072 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_8073 = torch.constant.int 6 + %6600 = torch.prims.convert_element_type %6599, %int6_8073 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_8074 = torch.constant.int 128 + %6601 = torch.aten.div.Scalar %6600, %int128_8074 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_8075 = torch.constant.float 5.000000e+05 + %6602 = torch.aten.pow.Scalar %float5.000000e05_8075, %6601 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6603 = torch.aten.reciprocal %6602 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_8076 = torch.constant.float 1.000000e+00 + %6604 = torch.aten.mul.Scalar %6603, %float1.000000e00_8076 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %6605 = torch.aten.reciprocal %6604 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_8077 = torch.constant.float 6.2831853071795862 + %6606 = torch.aten.mul.Scalar %6605, %float6.283190e00_8077 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_8078 = torch.constant.float 8.192000e+03 + %6607 = torch.aten.gt.Scalar %6606, %float8.192000e03_8078 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_8079 = torch.constant.int 8 + %6608 = torch.aten.div.Scalar %6604, %int8_8079 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %6609 = torch.aten.where.self %6607, %6608, %6604 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6610 = torch.aten.reciprocal %6606 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_8080 = torch.constant.int 8192 + %6611 = torch.aten.mul.Scalar %6610, %int8192_8080 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_8081 = torch.constant.int 1 + %int1_8082 = torch.constant.int 1 + %6612 = torch.aten.sub.Scalar 
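// note: the long - runs interleaved through this hunk (the old attention-output
// projection, residual add, RMSNorm with eps 9.9999997e-6, SwiGLU FFN with inner dim
// 14336, final norm, and the [?,128256] logits matmul) are the tail of the previous
// function body; the regeneration renumbers every SSA value, so these ops presumably
// reappear later in the + body rather than being deleted outright.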
%6611, %int1_8081, %int1_8082 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_8083 = torch.constant.int 3 + %6613 = torch.aten.div.Scalar %6612, %int3_8083 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_8084 = torch.constant.int 1 + %int1_8085 = torch.constant.int 1 + %6614 = torch.aten.rsub.Scalar %6613, %int1_8084, %int1_8085 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %6615 = torch.aten.mul.Tensor %6614, %6609 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_8086 = torch.constant.int 8 + %6616 = torch.aten.div.Scalar %6615, %int8_8086 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %6617 = torch.aten.mul.Tensor %6613, %6609 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> %int1_8087 = torch.constant.int 1 - %6606 = torch.aten.add.Scalar %6605, %float9.999990e-06_8086, %int1_8087 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6606, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %6607 = torch.aten.rsqrt %6606 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> - torch.bind_symbolic_shape %6607, [%292], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> - %6608 = torch.aten.mul.Tensor %6602, %6607 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6608, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_8088 = torch.constant.int 5 - %6609 = torch.prims.convert_element_type %6608, %int5_8088 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6609, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %6610 = torch.aten.mul.Tensor %289, %6609 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> - torch.bind_symbolic_shape %6610, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> - %int5_8089 = torch.constant.int 5 - %6611 = torch.prims.convert_element_type %6610, %int5_8089 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> - torch.bind_symbolic_shape %6611, [%292], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> - %int-2_8090 = torch.constant.int -2 - %int-1_8091 = torch.constant.int -1 - %6612 = torch.aten.transpose.int %290, %int-2_8090, %int-1_8091 : !torch.vtensor<[128256,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,128256],f16> - %int4_8092 = torch.constant.int 4 - %6613 = torch.aten.mul.int %int4_8092, %306 : !torch.int, !torch.int -> !torch.int - %int4096_8093 = torch.constant.int 4096 - %6614 = torch.prim.ListConstruct %6613, %int4096_8093 : (!torch.int, !torch.int) -> !torch.list - %6615 = torch.aten.view %6611, %6614 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> - torch.bind_symbolic_shape %6615, [%292], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> - %6616 = torch.aten.mm %6615, %6612 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,128256],f16> -> !torch.vtensor<[?,128256],f16> - torch.bind_symbolic_shape %6616, [%292], affine_map<()[s0] -> (s0 * 128, 128256)> : !torch.vtensor<[?,128256],f16> - %int4_8094 = torch.constant.int 4 + %6618 = torch.aten.add.Tensor %6616, 
%6617, %int1_8087 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_8088 = torch.constant.float 2.048000e+03 + %6619 = torch.aten.lt.Scalar %6606, %float2.048000e03_8088 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %6620 = torch.aten.bitwise_not %6619 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_8089 = torch.constant.float 8.192000e+03 + %6621 = torch.aten.gt.Scalar %6606, %float8.192000e03_8089 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %6622 = torch.aten.bitwise_not %6621 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6623 = torch.aten.mul.Tensor %6620, %6622 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6624 = torch.aten.where.self %6623, %6618, %6609 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6625 = torch.prim.ListConstruct %6624, %6624 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_8090 = torch.constant.int -1 + %6626 = torch.aten.cat %6625, %int-1_8090 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_8091 = torch.constant.int 6 + %6627 = torch.prims.convert_element_type %6626, %int6_8091 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_8092 = torch.constant.int 1 + %6628 = torch.aten.unsqueeze %6598, %int1_8092 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_8093 = torch.constant.int 6 + %6629 = torch.prims.convert_element_type %6628, %int6_8093 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_8094 = torch.constant.int 0 + %6630 = torch.aten.unsqueeze %6627, %int0_8094 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_8095 = torch.constant.int 6 + %6631 = torch.prims.convert_element_type %6630, %int6_8095 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %6632 = torch.aten.mul.Tensor %6629, %6631 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %6633 = torch.aten.cos %6632 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_8096 = torch.constant.int 5 + %6634 = torch.prims.convert_element_type %6633, %int5_8096 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %6635 = torch.aten.sin %6632 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_8097 = torch.constant.int 5 + %6636 = torch.prims.convert_element_type %6635, %int5_8097 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_8098 = torch.constant.int 0 + %int0_8099 = torch.constant.int 0 + %int1_8100 = torch.constant.int 1 + %6637 = torch.aten.slice.Tensor %6634, %int0_8098, %int0_8099, %298, %int1_8100 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6637, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_8101 = torch.constant.int 1 + %int0_8102 = torch.constant.int 0 + %int9223372036854775807_8103 = torch.constant.int 9223372036854775807 + %int1_8104 = torch.constant.int 1 + %6638 = torch.aten.slice.Tensor %6637, %int1_8101, %int0_8102, %int9223372036854775807_8103, %int1_8104 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> 
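// note: as with the query, the sliced cos/sin rows are reshaped [?,128] -> [1,?,1,128]
// and repeated along the batch dim to [4,?,1,128] so they broadcast over the 8 KV heads
// of the key tensor in the lines that follow.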
!torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6638, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_8105 = torch.constant.int 0 + %int0_8106 = torch.constant.int 0 + %int1_8107 = torch.constant.int 1 + %6639 = torch.aten.slice.Tensor %6636, %int0_8105, %int0_8106, %298, %int1_8107 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6639, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_8108 = torch.constant.int 1 + %int0_8109 = torch.constant.int 0 + %int9223372036854775807_8110 = torch.constant.int 9223372036854775807 + %int1_8111 = torch.constant.int 1 + %6640 = torch.aten.slice.Tensor %6639, %int1_8108, %int0_8109, %int9223372036854775807_8110, %int1_8111 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6640, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_8112 = torch.constant.int 0 + %6641 = torch.aten.unsqueeze %6638, %int0_8112 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6641, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_8113 = torch.constant.int 1 + %int0_8114 = torch.constant.int 0 + %int9223372036854775807_8115 = torch.constant.int 9223372036854775807 + %int1_8116 = torch.constant.int 1 + %6642 = torch.aten.slice.Tensor %6641, %int1_8113, %int0_8114, %int9223372036854775807_8115, %int1_8116 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6642, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_8117 = torch.constant.int 2 + %6643 = torch.aten.unsqueeze %6642, %int2_8117 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6643, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_8118 = torch.constant.int 3 + %int0_8119 = torch.constant.int 0 + %int9223372036854775807_8120 = torch.constant.int 9223372036854775807 + %int1_8121 = torch.constant.int 1 + %6644 = torch.aten.slice.Tensor %6643, %int3_8118, %int0_8119, %int9223372036854775807_8120, %int1_8121 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6644, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_8122 = torch.constant.int 4 + %int1_8123 = torch.constant.int 1 + %int1_8124 = torch.constant.int 1 + %int1_8125 = torch.constant.int 1 + %6645 = torch.prim.ListConstruct %int4_8122, %int1_8123, %int1_8124, %int1_8125 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6646 = torch.aten.repeat %6644, %6645 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6646, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_8126 = torch.constant.int 0 + %6647 = torch.aten.unsqueeze %6640, %int0_8126 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6647, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_8127 = torch.constant.int 1 + %int0_8128 = torch.constant.int 0 + 
%int9223372036854775807_8129 = torch.constant.int 9223372036854775807 + %int1_8130 = torch.constant.int 1 + %6648 = torch.aten.slice.Tensor %6647, %int1_8127, %int0_8128, %int9223372036854775807_8129, %int1_8130 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6648, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_8131 = torch.constant.int 2 + %6649 = torch.aten.unsqueeze %6648, %int2_8131 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6649, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_8132 = torch.constant.int 3 + %int0_8133 = torch.constant.int 0 + %int9223372036854775807_8134 = torch.constant.int 9223372036854775807 + %int1_8135 = torch.constant.int 1 + %6650 = torch.aten.slice.Tensor %6649, %int3_8132, %int0_8133, %int9223372036854775807_8134, %int1_8135 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6650, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_8136 = torch.constant.int 4 + %int1_8137 = torch.constant.int 1 + %int1_8138 = torch.constant.int 1 + %int1_8139 = torch.constant.int 1 + %6651 = torch.prim.ListConstruct %int4_8136, %int1_8137, %int1_8138, %int1_8139 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6652 = torch.aten.repeat %6650, %6651 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6652, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %6653 = torch.aten.mul.Tensor %6532, %6646 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6653, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_8140 = torch.constant.int 3 + %int0_8141 = torch.constant.int 0 + %int64_8142 = torch.constant.int 64 + %int1_8143 = torch.constant.int 1 + %6654 = torch.aten.slice.Tensor %6532, %int3_8140, %int0_8141, %int64_8142, %int1_8143 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %6654, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_8144 = torch.constant.int 3 + %int64_8145 = torch.constant.int 64 + %int9223372036854775807_8146 = torch.constant.int 9223372036854775807 + %int1_8147 = torch.constant.int 1 + %6655 = torch.aten.slice.Tensor %6532, %int3_8144, %int64_8145, %int9223372036854775807_8146, %int1_8147 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %6655, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %6656 = torch.aten.neg %6655 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %6656, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %6657 = torch.prim.ListConstruct %6656, %6654 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_8148 = torch.constant.int -1 + %6658 = torch.aten.cat %6657, %int-1_8148 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + 
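// note: once the key rotation completes (%6660), the + lines compute paged-cache row
// indices row = (page_id * 32 + 23) * 2 + kv, with kv = 0 for keys and 1 for values;
// the 32 appears to be the per-page layer count and 23 this block's index. The flat
// [?,2097152] cache is viewed as [?,32,2,8,32,128] (2097152 = 32*2*8*32*128), flattened
// to [?,8,32,128] rows, scattered via index_put, and viewed back to flat after each
// write; the intermediate flat round-trips are expected to fold away.
// A hedged Python sketch of the write; the function and variable names are assumptions:

import torch

def write_kv_page(cache_flat: torch.Tensor, page_ids: torch.Tensor,
                  layer: int, k: torch.Tensor, v: torch.Tensor) -> None:
    # cache_flat: [n_pages, 2097152]; k, v: [rows, 8, 32, 128], heads-major pages
    rows = cache_flat.view(-1, 32, 2, 8, 32, 128).reshape(-1, 8, 32, 128)
    base = (page_ids.reshape(-1) * 32 + layer) * 2   # %6661..%6663, layer = 23 here
    rows.index_put_((base + 0,), k)                  # key rows (%6664 .. %6678)
    rows.index_put_((base + 1,), v)                  # value rows (%6690 .. %6700)

// end of sketch.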
torch.bind_symbolic_shape %6658, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %6659 = torch.aten.mul.Tensor %6658, %6652 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6659, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_8149 = torch.constant.int 1 + %6660 = torch.aten.add.Tensor %6653, %6659, %int1_8149 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6660, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_8150 = torch.constant.int 32 + %6661 = torch.aten.mul.Scalar %arg2, %int32_8150 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6661, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int23 = torch.constant.int 23 + %int1_8151 = torch.constant.int 1 + %6662 = torch.aten.add.Scalar %6661, %int23, %int1_8151 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6662, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_8152 = torch.constant.int 2 + %6663 = torch.aten.mul.Scalar %6662, %int2_8152 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6663, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_8153 = torch.constant.int 0 + %int1_8154 = torch.constant.int 1 + %6664 = torch.aten.add.Scalar %6663, %int0_8153, %int1_8154 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6664, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %6665 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %6666 = torch.aten.view %6664, %6665 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %6666, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_8155 = torch.constant.int 4 + %int32_8156 = torch.constant.int 32 + %int8_8157 = torch.constant.int 8 + %int128_8158 = torch.constant.int 128 + %6667 = torch.prim.ListConstruct %int4_8155, %296, %int32_8156, %int8_8157, %int128_8158 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6668 = torch.aten.view %6660, %6667 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6668, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_8159 = torch.constant.int 32 + %int8_8160 = torch.constant.int 8 + %int128_8161 = torch.constant.int 128 + %6669 = torch.prim.ListConstruct %504, %int32_8159, %int8_8160, %int128_8161 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6670 = torch.aten.view %6668, %6669 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %6670, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_8162 = torch.constant.int 1 + %int2_8163 = torch.constant.int 2 + %6671 = torch.aten.transpose.int %6670, %int1_8162, %int2_8163 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6671, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : 
!torch.vtensor<[?,8,32,128],f16> + %int5_8164 = torch.constant.int 5 + %6672 = torch.prims.convert_element_type %6671, %int5_8164 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6672, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_8165 = torch.constant.int 32 + %int2_8166 = torch.constant.int 2 + %int8_8167 = torch.constant.int 8 + %int32_8168 = torch.constant.int 32 + %int128_8169 = torch.constant.int 128 + %6673 = torch.prim.ListConstruct %297, %int32_8165, %int2_8166, %int8_8167, %int32_8168, %int128_8169 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6674 = torch.aten.view %6436, %6673 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6674, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_8170 = torch.constant.int 8 + %int32_8171 = torch.constant.int 32 + %int128_8172 = torch.constant.int 128 + %6675 = torch.prim.ListConstruct %497, %int8_8170, %int32_8171, %int128_8172 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6676 = torch.aten.view %6674, %6675 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6676, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %6677 = torch.prim.ListConstruct %6666 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_8173 = torch.constant.bool false + %6678 = torch.aten.index_put %6676, %6677, %6672, %false_8173 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6678, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_8174 = torch.constant.int 32 + %int2_8175 = torch.constant.int 2 + %int8_8176 = torch.constant.int 8 + %int32_8177 = torch.constant.int 32 + %int128_8178 = torch.constant.int 128 + %6679 = torch.prim.ListConstruct %297, %int32_8174, %int2_8175, %int8_8176, %int32_8177, %int128_8178 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6680 = torch.aten.view %6678, %6679 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6680, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_8179 = torch.constant.int 2097152 + %6681 = torch.prim.ListConstruct %297, %int2097152_8179 : (!torch.int, !torch.int) -> !torch.list + %6682 = torch.aten.view %6680, %6681 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6682, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_8180 = torch.constant.int 32 + %int2_8181 = torch.constant.int 2 + %int8_8182 = torch.constant.int 8 + %int32_8183 = torch.constant.int 32 + %int128_8184 = torch.constant.int 128 + %6683 = torch.prim.ListConstruct %297, %int32_8180, %int2_8181, %int8_8182, %int32_8183, %int128_8184 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6684 = torch.aten.view %6682, %6683 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6684, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : 
!torch.vtensor<[?,32,2,8,32,128],f16> + %int8_8185 = torch.constant.int 8 + %int32_8186 = torch.constant.int 32 + %int128_8187 = torch.constant.int 128 + %6685 = torch.prim.ListConstruct %497, %int8_8185, %int32_8186, %int128_8187 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6686 = torch.aten.view %6684, %6685 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6686, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_8188 = torch.constant.int 32 + %6687 = torch.aten.mul.Scalar %arg2, %int32_8188 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6687, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int23_8189 = torch.constant.int 23 + %int1_8190 = torch.constant.int 1 + %6688 = torch.aten.add.Scalar %6687, %int23_8189, %int1_8190 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6688, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_8191 = torch.constant.int 2 + %6689 = torch.aten.mul.Scalar %6688, %int2_8191 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6689, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_8192 = torch.constant.int 1 + %int1_8193 = torch.constant.int 1 + %6690 = torch.aten.add.Scalar %6689, %int1_8192, %int1_8193 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6690, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %6691 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %6692 = torch.aten.view %6690, %6691 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %6692, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_8194 = torch.constant.int 4 + %int32_8195 = torch.constant.int 32 + %int8_8196 = torch.constant.int 8 + %int128_8197 = torch.constant.int 128 + %6693 = torch.prim.ListConstruct %int4_8194, %296, %int32_8195, %int8_8196, %int128_8197 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6694 = torch.aten.view %6534, %6693 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6694, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_8198 = torch.constant.int 32 + %int8_8199 = torch.constant.int 8 + %int128_8200 = torch.constant.int 128 + %6695 = torch.prim.ListConstruct %504, %int32_8198, %int8_8199, %int128_8200 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6696 = torch.aten.view %6694, %6695 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %6696, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_8201 = torch.constant.int 1 + %int2_8202 = torch.constant.int 2 + %6697 = torch.aten.transpose.int %6696, %int1_8201, %int2_8202 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6697, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_8203 = torch.constant.int 5 + %6698 = torch.prims.convert_element_type %6697, %int5_8203 : !torch.vtensor<[?,8,32,128],f16>, 
!torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6698, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %6699 = torch.prim.ListConstruct %6692 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_8204 = torch.constant.bool false + %6700 = torch.aten.index_put %6686, %6699, %6698, %false_8204 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6700, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_8205 = torch.constant.int 32 + %int2_8206 = torch.constant.int 2 + %int8_8207 = torch.constant.int 8 + %int32_8208 = torch.constant.int 32 + %int128_8209 = torch.constant.int 128 + %6701 = torch.prim.ListConstruct %297, %int32_8205, %int2_8206, %int8_8207, %int32_8208, %int128_8209 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6702 = torch.aten.view %6700, %6701 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6702, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_8210 = torch.constant.int 2097152 + %6703 = torch.prim.ListConstruct %297, %int2097152_8210 : (!torch.int, !torch.int) -> !torch.list + %6704 = torch.aten.view %6702, %6703 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6704, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_8211 = torch.constant.int -2 + %6705 = torch.aten.unsqueeze %6660, %int-2_8211 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %6705, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_8212 = torch.constant.int 4 + %int8_8213 = torch.constant.int 8 + %int4_8214 = torch.constant.int 4 + %int128_8215 = torch.constant.int 128 + %6706 = torch.prim.ListConstruct %int4_8212, %298, %int8_8213, %int4_8214, %int128_8215 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_8216 = torch.constant.bool false + %6707 = torch.aten.expand %6705, %6706, %false_8216 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6707, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_8217 = torch.constant.int 0 + %6708 = torch.aten.clone %6707, %int0_8217 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6708, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_8218 = torch.constant.int 4 + %int32_8219 = torch.constant.int 32 + %int128_8220 = torch.constant.int 128 + %6709 = torch.prim.ListConstruct %int4_8218, %298, %int32_8219, %int128_8220 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6710 = torch.aten._unsafe_view %6708, %6709 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6710, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_8221 = torch.constant.int -2 + %6711 = torch.aten.unsqueeze %6534, %int-2_8221 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + 
torch.bind_symbolic_shape %6711, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_8222 = torch.constant.int 4 + %int8_8223 = torch.constant.int 8 + %int4_8224 = torch.constant.int 4 + %int128_8225 = torch.constant.int 128 + %6712 = torch.prim.ListConstruct %int4_8222, %298, %int8_8223, %int4_8224, %int128_8225 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_8226 = torch.constant.bool false + %6713 = torch.aten.expand %6711, %6712, %false_8226 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6713, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_8227 = torch.constant.int 0 + %6714 = torch.aten.clone %6713, %int0_8227 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6714, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_8228 = torch.constant.int 4 + %int32_8229 = torch.constant.int 32 + %int128_8230 = torch.constant.int 128 + %6715 = torch.prim.ListConstruct %int4_8228, %298, %int32_8229, %int128_8230 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6716 = torch.aten._unsafe_view %6714, %6715 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6716, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_8231 = torch.constant.int 1 + %int2_8232 = torch.constant.int 2 + %6717 = torch.aten.transpose.int %6597, %int1_8231, %int2_8232 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6717, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_8233 = torch.constant.int 1 + %int2_8234 = torch.constant.int 2 + %6718 = torch.aten.transpose.int %6710, %int1_8233, %int2_8234 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6718, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_8235 = torch.constant.int 1 + %int2_8236 = torch.constant.int 2 + %6719 = torch.aten.transpose.int %6716, %int1_8235, %int2_8236 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6719, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_8237 = torch.constant.float 0.000000e+00 + %false_8238 = torch.constant.bool false + %none_8239 = torch.constant.none + %6720:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6717, %6718, %6719, %float0.000000e00_8237, %false_8238, %327, %none_8239) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %6720#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_8240 = torch.constant.int 1 + %int2_8241 = torch.constant.int 2 + %6721 = torch.aten.transpose.int %6720#0, %int1_8240, %int2_8241 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape 
%6721, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_8242 = torch.constant.int 4 + %int4096_8243 = torch.constant.int 4096 + %6722 = torch.prim.ListConstruct %int4_8242, %298, %int4096_8243 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6723 = torch.aten.view %6721, %6722 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6723, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_8244 = torch.constant.int -2 + %int-1_8245 = torch.constant.int -1 + %6724 = torch.aten.transpose.int %213, %int-2_8244, %int-1_8245 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_8246 = torch.constant.int 5 + %6725 = torch.prims.convert_element_type %6724, %int5_8246 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_8247 = torch.constant.int 4096 + %6726 = torch.prim.ListConstruct %342, %int4096_8247 : (!torch.int, !torch.int) -> !torch.list + %6727 = torch.aten.view %6723, %6726 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6727, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6728 = torch.aten.mm %6727, %6725 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6728, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_8248 = torch.constant.int 4 + %int4096_8249 = torch.constant.int 4096 + %6729 = torch.prim.ListConstruct %int4_8248, %298, %int4096_8249 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6730 = torch.aten.view %6728, %6729 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6730, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_8250 = torch.constant.int 1 + %6731 = torch.aten.add.Tensor %6497, %6730, %int1_8250 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6731, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_8251 = torch.constant.int 6 + %6732 = torch.prims.convert_element_type %6731, %int6_8251 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6732, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_8252 = torch.constant.int 2 + %6733 = torch.aten.pow.Tensor_Scalar %6732, %int2_8252 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6733, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_8253 = torch.constant.int -1 + %6734 = torch.prim.ListConstruct %int-1_8253 : (!torch.int) -> !torch.list + %true_8254 = torch.constant.bool true + %none_8255 = torch.constant.none + %6735 = torch.aten.mean.dim %6733, %6734, %true_8254, %none_8255 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6735, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_8256 = torch.constant.float 9.9999997473787516E-6 + %int1_8257 = torch.constant.int 1 + %6736 = torch.aten.add.Scalar %6735, %float9.999990e-06_8256, %int1_8257 : 
!torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6736, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %6737 = torch.aten.rsqrt %6736 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6737, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %6738 = torch.aten.mul.Tensor %6732, %6737 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6738, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_8258 = torch.constant.int 5 + %6739 = torch.prims.convert_element_type %6738, %int5_8258 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6739, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %6740 = torch.aten.mul.Tensor %214, %6739 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6740, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_8259 = torch.constant.int 5 + %6741 = torch.prims.convert_element_type %6740, %int5_8259 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6741, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_8260 = torch.constant.int -2 + %int-1_8261 = torch.constant.int -1 + %6742 = torch.aten.transpose.int %215, %int-2_8260, %int-1_8261 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_8262 = torch.constant.int 5 + %6743 = torch.prims.convert_element_type %6742, %int5_8262 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_8263 = torch.constant.int 4096 + %6744 = torch.prim.ListConstruct %342, %int4096_8263 : (!torch.int, !torch.int) -> !torch.list + %6745 = torch.aten.view %6741, %6744 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6745, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6746 = torch.aten.mm %6745, %6743 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %6746, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_8264 = torch.constant.int 4 + %int14336_8265 = torch.constant.int 14336 + %6747 = torch.prim.ListConstruct %int4_8264, %298, %int14336_8265 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6748 = torch.aten.view %6746, %6747 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %6748, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %6749 = torch.aten.silu %6748 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %6749, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_8266 = torch.constant.int -2 + %int-1_8267 = torch.constant.int -1 + %6750 = torch.aten.transpose.int %216, %int-2_8266, %int-1_8267 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_8268 = torch.constant.int 5 + %6751 = torch.prims.convert_element_type %6750, %int5_8268 : 
!torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_8269 = torch.constant.int 4096 + %6752 = torch.prim.ListConstruct %342, %int4096_8269 : (!torch.int, !torch.int) -> !torch.list + %6753 = torch.aten.view %6741, %6752 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6753, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6754 = torch.aten.mm %6753, %6751 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %6754, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_8270 = torch.constant.int 4 + %int14336_8271 = torch.constant.int 14336 + %6755 = torch.prim.ListConstruct %int4_8270, %298, %int14336_8271 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6756 = torch.aten.view %6754, %6755 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %6756, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %6757 = torch.aten.mul.Tensor %6749, %6756 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %6757, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_8272 = torch.constant.int -2 + %int-1_8273 = torch.constant.int -1 + %6758 = torch.aten.transpose.int %217, %int-2_8272, %int-1_8273 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_8274 = torch.constant.int 5 + %6759 = torch.prims.convert_element_type %6758, %int5_8274 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_8275 = torch.constant.int 14336 + %6760 = torch.prim.ListConstruct %342, %int14336_8275 : (!torch.int, !torch.int) -> !torch.list + %6761 = torch.aten.view %6757, %6760 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %6761, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %6762 = torch.aten.mm %6761, %6759 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6762, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_8276 = torch.constant.int 4 + %int4096_8277 = torch.constant.int 4096 + %6763 = torch.prim.ListConstruct %int4_8276, %298, %int4096_8277 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6764 = torch.aten.view %6762, %6763 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6764, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_8278 = torch.constant.int 1 + %6765 = torch.aten.add.Tensor %6731, %6764, %int1_8278 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6765, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_8279 = torch.constant.int 6 + %6766 = torch.prims.convert_element_type %6765, %int6_8279 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6766, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_8280 = torch.constant.int 2 + %6767 = 
torch.aten.pow.Tensor_Scalar %6766, %int2_8280 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6767, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_8281 = torch.constant.int -1 + %6768 = torch.prim.ListConstruct %int-1_8281 : (!torch.int) -> !torch.list + %true_8282 = torch.constant.bool true + %none_8283 = torch.constant.none + %6769 = torch.aten.mean.dim %6767, %6768, %true_8282, %none_8283 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6769, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_8284 = torch.constant.float 9.9999997473787516E-6 + %int1_8285 = torch.constant.int 1 + %6770 = torch.aten.add.Scalar %6769, %float9.999990e-06_8284, %int1_8285 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6770, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %6771 = torch.aten.rsqrt %6770 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %6771, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %6772 = torch.aten.mul.Tensor %6766, %6771 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6772, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_8286 = torch.constant.int 5 + %6773 = torch.prims.convert_element_type %6772, %int5_8286 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6773, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %6774 = torch.aten.mul.Tensor %218, %6773 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %6774, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_8287 = torch.constant.int 5 + %6775 = torch.prims.convert_element_type %6774, %int5_8287 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6775, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_8288 = torch.constant.int -2 + %int-1_8289 = torch.constant.int -1 + %6776 = torch.aten.transpose.int %219, %int-2_8288, %int-1_8289 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_8290 = torch.constant.int 5 + %6777 = torch.prims.convert_element_type %6776, %int5_8290 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_8291 = torch.constant.int 4096 + %6778 = torch.prim.ListConstruct %342, %int4096_8291 : (!torch.int, !torch.int) -> !torch.list + %6779 = torch.aten.view %6775, %6778 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6779, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6780 = torch.aten.mm %6779, %6777 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6780, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_8292 = torch.constant.int 4 + %int4096_8293 = torch.constant.int 4096 + %6781 = torch.prim.ListConstruct %int4_8292, 
%298, %int4096_8293 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6782 = torch.aten.view %6780, %6781 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6782, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_8294 = torch.constant.int -2 + %int-1_8295 = torch.constant.int -1 + %6783 = torch.aten.transpose.int %220, %int-2_8294, %int-1_8295 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_8296 = torch.constant.int 5 + %6784 = torch.prims.convert_element_type %6783, %int5_8296 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_8297 = torch.constant.int 4096 + %6785 = torch.prim.ListConstruct %342, %int4096_8297 : (!torch.int, !torch.int) -> !torch.list + %6786 = torch.aten.view %6775, %6785 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6786, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6787 = torch.aten.mm %6786, %6784 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %6787, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_8298 = torch.constant.int 4 + %int1024_8299 = torch.constant.int 1024 + %6788 = torch.prim.ListConstruct %int4_8298, %298, %int1024_8299 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6789 = torch.aten.view %6787, %6788 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %6789, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_8300 = torch.constant.int -2 + %int-1_8301 = torch.constant.int -1 + %6790 = torch.aten.transpose.int %221, %int-2_8300, %int-1_8301 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_8302 = torch.constant.int 5 + %6791 = torch.prims.convert_element_type %6790, %int5_8302 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_8303 = torch.constant.int 4096 + %6792 = torch.prim.ListConstruct %342, %int4096_8303 : (!torch.int, !torch.int) -> !torch.list + %6793 = torch.aten.view %6775, %6792 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6793, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6794 = torch.aten.mm %6793, %6791 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %6794, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_8304 = torch.constant.int 4 + %int1024_8305 = torch.constant.int 1024 + %6795 = torch.prim.ListConstruct %int4_8304, %298, %int1024_8305 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6796 = torch.aten.view %6794, %6795 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %6796, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_8306 = torch.constant.int 4 + %int32_8307 = torch.constant.int 32 + %int128_8308 = torch.constant.int 128 + %6797 = torch.prim.ListConstruct %int4_8306, %298, %int32_8307, %int128_8308 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6798 = torch.aten.view %6782, %6797 : 
!torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6798, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_8309 = torch.constant.int 4 + %int8_8310 = torch.constant.int 8 + %int128_8311 = torch.constant.int 128 + %6799 = torch.prim.ListConstruct %int4_8309, %298, %int8_8310, %int128_8311 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6800 = torch.aten.view %6789, %6799 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6800, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_8312 = torch.constant.int 4 + %int8_8313 = torch.constant.int 8 + %int128_8314 = torch.constant.int 128 + %6801 = torch.prim.ListConstruct %int4_8312, %298, %int8_8313, %int128_8314 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6802 = torch.aten.view %6796, %6801 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6802, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_8315 = torch.constant.int 131072 + %none_8316 = torch.constant.none + %none_8317 = torch.constant.none + %cpu_8318 = torch.constant.device "cpu" + %false_8319 = torch.constant.bool false + %6803 = torch.aten.arange %int131072_8315, %none_8316, %none_8317, %cpu_8318, %false_8319 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_8320 = torch.constant.int 0 + %int128_8321 = torch.constant.int 128 + %int2_8322 = torch.constant.int 2 + %int4_8323 = torch.constant.int 4 + %none_8324 = torch.constant.none + %cpu_8325 = torch.constant.device "cpu" + %false_8326 = torch.constant.bool false + %6804 = torch.aten.arange.start_step %int0_8320, %int128_8321, %int2_8322, %int4_8323, %none_8324, %cpu_8325, %false_8326 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_8327 = torch.constant.int 6 + %6805 = torch.prims.convert_element_type %6804, %int6_8327 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_8328 = torch.constant.int 128 + %6806 = torch.aten.div.Scalar %6805, %int128_8328 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_8329 = torch.constant.float 5.000000e+05 + %6807 = torch.aten.pow.Scalar %float5.000000e05_8329, %6806 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6808 = torch.aten.reciprocal %6807 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_8330 = torch.constant.float 1.000000e+00 + %6809 = torch.aten.mul.Scalar %6808, %float1.000000e00_8330 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %6810 = torch.aten.reciprocal %6809 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_8331 = torch.constant.float 6.2831853071795862 + %6811 = torch.aten.mul.Scalar %6810, %float6.283190e00_8331 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_8332 = torch.constant.float 8.192000e+03 + %6812 = torch.aten.gt.Scalar %6811, %float8.192000e03_8332 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_8333 = torch.constant.int 8 + %6813 = torch.aten.div.Scalar %6809, %int8_8333 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %6814 = 
torch.aten.where.self %6812, %6813, %6809 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6815 = torch.aten.reciprocal %6811 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_8334 = torch.constant.int 8192 + %6816 = torch.aten.mul.Scalar %6815, %int8192_8334 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_8335 = torch.constant.int 1 + %int1_8336 = torch.constant.int 1 + %6817 = torch.aten.sub.Scalar %6816, %int1_8335, %int1_8336 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_8337 = torch.constant.int 3 + %6818 = torch.aten.div.Scalar %6817, %int3_8337 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_8338 = torch.constant.int 1 + %int1_8339 = torch.constant.int 1 + %6819 = torch.aten.rsub.Scalar %6818, %int1_8338, %int1_8339 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %6820 = torch.aten.mul.Tensor %6819, %6814 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_8340 = torch.constant.int 8 + %6821 = torch.aten.div.Scalar %6820, %int8_8340 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %6822 = torch.aten.mul.Tensor %6818, %6814 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_8341 = torch.constant.int 1 + %6823 = torch.aten.add.Tensor %6821, %6822, %int1_8341 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_8342 = torch.constant.float 2.048000e+03 + %6824 = torch.aten.lt.Scalar %6811, %float2.048000e03_8342 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %6825 = torch.aten.bitwise_not %6824 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_8343 = torch.constant.float 8.192000e+03 + %6826 = torch.aten.gt.Scalar %6811, %float8.192000e03_8343 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %6827 = torch.aten.bitwise_not %6826 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6828 = torch.aten.mul.Tensor %6825, %6827 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6829 = torch.aten.where.self %6828, %6823, %6814 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6830 = torch.prim.ListConstruct %6829, %6829 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_8344 = torch.constant.int -1 + %6831 = torch.aten.cat %6830, %int-1_8344 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_8345 = torch.constant.int 6 + %6832 = torch.prims.convert_element_type %6831, %int6_8345 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_8346 = torch.constant.int 1 + %6833 = torch.aten.unsqueeze %6803, %int1_8346 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_8347 = torch.constant.int 6 + %6834 = torch.prims.convert_element_type %6833, %int6_8347 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_8348 = torch.constant.int 0 + %6835 = torch.aten.unsqueeze %6832, %int0_8348 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_8349 = torch.constant.int 6 + %6836 = torch.prims.convert_element_type %6835, %int6_8349 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %6837 = torch.aten.mul.Tensor %6834, %6836 : 
!torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %6838 = torch.aten.cos %6837 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_8350 = torch.constant.int 5 + %6839 = torch.prims.convert_element_type %6838, %int5_8350 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %6840 = torch.aten.sin %6837 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_8351 = torch.constant.int 5 + %6841 = torch.prims.convert_element_type %6840, %int5_8351 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_8352 = torch.constant.int 0 + %int0_8353 = torch.constant.int 0 + %int1_8354 = torch.constant.int 1 + %6842 = torch.aten.slice.Tensor %6839, %int0_8352, %int0_8353, %298, %int1_8354 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6842, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_8355 = torch.constant.int 1 + %int0_8356 = torch.constant.int 0 + %int9223372036854775807_8357 = torch.constant.int 9223372036854775807 + %int1_8358 = torch.constant.int 1 + %6843 = torch.aten.slice.Tensor %6842, %int1_8355, %int0_8356, %int9223372036854775807_8357, %int1_8358 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6843, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_8359 = torch.constant.int 0 + %int0_8360 = torch.constant.int 0 + %int1_8361 = torch.constant.int 1 + %6844 = torch.aten.slice.Tensor %6841, %int0_8359, %int0_8360, %298, %int1_8361 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6844, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_8362 = torch.constant.int 1 + %int0_8363 = torch.constant.int 0 + %int9223372036854775807_8364 = torch.constant.int 9223372036854775807 + %int1_8365 = torch.constant.int 1 + %6845 = torch.aten.slice.Tensor %6844, %int1_8362, %int0_8363, %int9223372036854775807_8364, %int1_8365 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6845, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_8366 = torch.constant.int 0 + %6846 = torch.aten.unsqueeze %6843, %int0_8366 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6846, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_8367 = torch.constant.int 1 + %int0_8368 = torch.constant.int 0 + %int9223372036854775807_8369 = torch.constant.int 9223372036854775807 + %int1_8370 = torch.constant.int 1 + %6847 = torch.aten.slice.Tensor %6846, %int1_8367, %int0_8368, %int9223372036854775807_8369, %int1_8370 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6847, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_8371 = torch.constant.int 2 + %6848 = torch.aten.unsqueeze %6847, %int2_8371 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6848, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : 
!torch.vtensor<[1,?,1,128],f16> + %int3_8372 = torch.constant.int 3 + %int0_8373 = torch.constant.int 0 + %int9223372036854775807_8374 = torch.constant.int 9223372036854775807 + %int1_8375 = torch.constant.int 1 + %6849 = torch.aten.slice.Tensor %6848, %int3_8372, %int0_8373, %int9223372036854775807_8374, %int1_8375 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6849, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_8376 = torch.constant.int 4 + %int1_8377 = torch.constant.int 1 + %int1_8378 = torch.constant.int 1 + %int1_8379 = torch.constant.int 1 + %6850 = torch.prim.ListConstruct %int4_8376, %int1_8377, %int1_8378, %int1_8379 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6851 = torch.aten.repeat %6849, %6850 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6851, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_8380 = torch.constant.int 0 + %6852 = torch.aten.unsqueeze %6845, %int0_8380 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6852, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_8381 = torch.constant.int 1 + %int0_8382 = torch.constant.int 0 + %int9223372036854775807_8383 = torch.constant.int 9223372036854775807 + %int1_8384 = torch.constant.int 1 + %6853 = torch.aten.slice.Tensor %6852, %int1_8381, %int0_8382, %int9223372036854775807_8383, %int1_8384 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6853, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_8385 = torch.constant.int 2 + %6854 = torch.aten.unsqueeze %6853, %int2_8385 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6854, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_8386 = torch.constant.int 3 + %int0_8387 = torch.constant.int 0 + %int9223372036854775807_8388 = torch.constant.int 9223372036854775807 + %int1_8389 = torch.constant.int 1 + %6855 = torch.aten.slice.Tensor %6854, %int3_8386, %int0_8387, %int9223372036854775807_8388, %int1_8389 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6855, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_8390 = torch.constant.int 4 + %int1_8391 = torch.constant.int 1 + %int1_8392 = torch.constant.int 1 + %int1_8393 = torch.constant.int 1 + %6856 = torch.prim.ListConstruct %int4_8390, %int1_8391, %int1_8392, %int1_8393 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6857 = torch.aten.repeat %6855, %6856 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6857, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %6858 = torch.aten.mul.Tensor %6798, %6851 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6858, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_8394 = torch.constant.int 3 + %int0_8395 = 
torch.constant.int 0 + %int64_8396 = torch.constant.int 64 + %int1_8397 = torch.constant.int 1 + %6859 = torch.aten.slice.Tensor %6798, %int3_8394, %int0_8395, %int64_8396, %int1_8397 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %6859, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_8398 = torch.constant.int 3 + %int64_8399 = torch.constant.int 64 + %int9223372036854775807_8400 = torch.constant.int 9223372036854775807 + %int1_8401 = torch.constant.int 1 + %6860 = torch.aten.slice.Tensor %6798, %int3_8398, %int64_8399, %int9223372036854775807_8400, %int1_8401 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %6860, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %6861 = torch.aten.neg %6860 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %6861, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %6862 = torch.prim.ListConstruct %6861, %6859 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_8402 = torch.constant.int -1 + %6863 = torch.aten.cat %6862, %int-1_8402 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6863, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %6864 = torch.aten.mul.Tensor %6863, %6857 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6864, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_8403 = torch.constant.int 1 + %6865 = torch.aten.add.Tensor %6858, %6864, %int1_8403 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6865, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_8404 = torch.constant.int 131072 + %none_8405 = torch.constant.none + %none_8406 = torch.constant.none + %cpu_8407 = torch.constant.device "cpu" + %false_8408 = torch.constant.bool false + %6866 = torch.aten.arange %int131072_8404, %none_8405, %none_8406, %cpu_8407, %false_8408 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_8409 = torch.constant.int 0 + %int128_8410 = torch.constant.int 128 + %int2_8411 = torch.constant.int 2 + %int4_8412 = torch.constant.int 4 + %none_8413 = torch.constant.none + %cpu_8414 = torch.constant.device "cpu" + %false_8415 = torch.constant.bool false + %6867 = torch.aten.arange.start_step %int0_8409, %int128_8410, %int2_8411, %int4_8412, %none_8413, %cpu_8414, %false_8415 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_8416 = torch.constant.int 6 + %6868 = torch.prims.convert_element_type %6867, %int6_8416 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_8417 = torch.constant.int 128 + %6869 = torch.aten.div.Scalar %6868, %int128_8417 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_8418 = torch.constant.float 5.000000e+05 + %6870 = torch.aten.pow.Scalar %float5.000000e05_8418, %6869 : !torch.float, 
!torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6871 = torch.aten.reciprocal %6870 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_8419 = torch.constant.float 1.000000e+00 + %6872 = torch.aten.mul.Scalar %6871, %float1.000000e00_8419 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %6873 = torch.aten.reciprocal %6872 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_8420 = torch.constant.float 6.2831853071795862 + %6874 = torch.aten.mul.Scalar %6873, %float6.283190e00_8420 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_8421 = torch.constant.float 8.192000e+03 + %6875 = torch.aten.gt.Scalar %6874, %float8.192000e03_8421 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_8422 = torch.constant.int 8 + %6876 = torch.aten.div.Scalar %6872, %int8_8422 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %6877 = torch.aten.where.self %6875, %6876, %6872 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6878 = torch.aten.reciprocal %6874 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_8423 = torch.constant.int 8192 + %6879 = torch.aten.mul.Scalar %6878, %int8192_8423 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_8424 = torch.constant.int 1 + %int1_8425 = torch.constant.int 1 + %6880 = torch.aten.sub.Scalar %6879, %int1_8424, %int1_8425 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_8426 = torch.constant.int 3 + %6881 = torch.aten.div.Scalar %6880, %int3_8426 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_8427 = torch.constant.int 1 + %int1_8428 = torch.constant.int 1 + %6882 = torch.aten.rsub.Scalar %6881, %int1_8427, %int1_8428 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %6883 = torch.aten.mul.Tensor %6882, %6877 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_8429 = torch.constant.int 8 + %6884 = torch.aten.div.Scalar %6883, %int8_8429 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %6885 = torch.aten.mul.Tensor %6881, %6877 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_8430 = torch.constant.int 1 + %6886 = torch.aten.add.Tensor %6884, %6885, %int1_8430 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_8431 = torch.constant.float 2.048000e+03 + %6887 = torch.aten.lt.Scalar %6874, %float2.048000e03_8431 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %6888 = torch.aten.bitwise_not %6887 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_8432 = torch.constant.float 8.192000e+03 + %6889 = torch.aten.gt.Scalar %6874, %float8.192000e03_8432 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %6890 = torch.aten.bitwise_not %6889 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6891 = torch.aten.mul.Tensor %6888, %6890 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %6892 = torch.aten.where.self %6891, %6886, %6877 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %6893 = torch.prim.ListConstruct %6892, %6892 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_8433 = torch.constant.int 
-1 + %6894 = torch.aten.cat %6893, %int-1_8433 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_8434 = torch.constant.int 6 + %6895 = torch.prims.convert_element_type %6894, %int6_8434 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_8435 = torch.constant.int 1 + %6896 = torch.aten.unsqueeze %6866, %int1_8435 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_8436 = torch.constant.int 6 + %6897 = torch.prims.convert_element_type %6896, %int6_8436 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_8437 = torch.constant.int 0 + %6898 = torch.aten.unsqueeze %6895, %int0_8437 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_8438 = torch.constant.int 6 + %6899 = torch.prims.convert_element_type %6898, %int6_8438 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %6900 = torch.aten.mul.Tensor %6897, %6899 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %6901 = torch.aten.cos %6900 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_8439 = torch.constant.int 5 + %6902 = torch.prims.convert_element_type %6901, %int5_8439 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %6903 = torch.aten.sin %6900 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_8440 = torch.constant.int 5 + %6904 = torch.prims.convert_element_type %6903, %int5_8440 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_8441 = torch.constant.int 0 + %int0_8442 = torch.constant.int 0 + %int1_8443 = torch.constant.int 1 + %6905 = torch.aten.slice.Tensor %6902, %int0_8441, %int0_8442, %298, %int1_8443 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6905, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_8444 = torch.constant.int 1 + %int0_8445 = torch.constant.int 0 + %int9223372036854775807_8446 = torch.constant.int 9223372036854775807 + %int1_8447 = torch.constant.int 1 + %6906 = torch.aten.slice.Tensor %6905, %int1_8444, %int0_8445, %int9223372036854775807_8446, %int1_8447 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6906, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_8448 = torch.constant.int 0 + %int0_8449 = torch.constant.int 0 + %int1_8450 = torch.constant.int 1 + %6907 = torch.aten.slice.Tensor %6904, %int0_8448, %int0_8449, %298, %int1_8450 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6907, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_8451 = torch.constant.int 1 + %int0_8452 = torch.constant.int 0 + %int9223372036854775807_8453 = torch.constant.int 9223372036854775807 + %int1_8454 = torch.constant.int 1 + %6908 = torch.aten.slice.Tensor %6907, %int1_8451, %int0_8452, %int9223372036854775807_8453, %int1_8454 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6908, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_8455 = torch.constant.int 0 + %6909 = torch.aten.unsqueeze %6906, 
%int0_8455 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6909, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_8456 = torch.constant.int 1 + %int0_8457 = torch.constant.int 0 + %int9223372036854775807_8458 = torch.constant.int 9223372036854775807 + %int1_8459 = torch.constant.int 1 + %6910 = torch.aten.slice.Tensor %6909, %int1_8456, %int0_8457, %int9223372036854775807_8458, %int1_8459 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6910, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_8460 = torch.constant.int 2 + %6911 = torch.aten.unsqueeze %6910, %int2_8460 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6911, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_8461 = torch.constant.int 3 + %int0_8462 = torch.constant.int 0 + %int9223372036854775807_8463 = torch.constant.int 9223372036854775807 + %int1_8464 = torch.constant.int 1 + %6912 = torch.aten.slice.Tensor %6911, %int3_8461, %int0_8462, %int9223372036854775807_8463, %int1_8464 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6912, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_8465 = torch.constant.int 4 + %int1_8466 = torch.constant.int 1 + %int1_8467 = torch.constant.int 1 + %int1_8468 = torch.constant.int 1 + %6913 = torch.prim.ListConstruct %int4_8465, %int1_8466, %int1_8467, %int1_8468 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6914 = torch.aten.repeat %6912, %6913 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6914, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_8469 = torch.constant.int 0 + %6915 = torch.aten.unsqueeze %6908, %int0_8469 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6915, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_8470 = torch.constant.int 1 + %int0_8471 = torch.constant.int 0 + %int9223372036854775807_8472 = torch.constant.int 9223372036854775807 + %int1_8473 = torch.constant.int 1 + %6916 = torch.aten.slice.Tensor %6915, %int1_8470, %int0_8471, %int9223372036854775807_8472, %int1_8473 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %6916, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_8474 = torch.constant.int 2 + %6917 = torch.aten.unsqueeze %6916, %int2_8474 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %6917, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_8475 = torch.constant.int 3 + %int0_8476 = torch.constant.int 0 + %int9223372036854775807_8477 = torch.constant.int 9223372036854775807 + %int1_8478 = torch.constant.int 1 + %6918 = torch.aten.slice.Tensor %6917, %int3_8475, %int0_8476, %int9223372036854775807_8477, %int1_8478 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + 
torch.bind_symbolic_shape %6918, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_8479 = torch.constant.int 4 + %int1_8480 = torch.constant.int 1 + %int1_8481 = torch.constant.int 1 + %int1_8482 = torch.constant.int 1 + %6919 = torch.prim.ListConstruct %int4_8479, %int1_8480, %int1_8481, %int1_8482 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6920 = torch.aten.repeat %6918, %6919 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %6920, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %6921 = torch.aten.mul.Tensor %6800, %6914 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6921, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_8483 = torch.constant.int 3 + %int0_8484 = torch.constant.int 0 + %int64_8485 = torch.constant.int 64 + %int1_8486 = torch.constant.int 1 + %6922 = torch.aten.slice.Tensor %6800, %int3_8483, %int0_8484, %int64_8485, %int1_8486 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %6922, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_8487 = torch.constant.int 3 + %int64_8488 = torch.constant.int 64 + %int9223372036854775807_8489 = torch.constant.int 9223372036854775807 + %int1_8490 = torch.constant.int 1 + %6923 = torch.aten.slice.Tensor %6800, %int3_8487, %int64_8488, %int9223372036854775807_8489, %int1_8490 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %6923, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %6924 = torch.aten.neg %6923 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %6924, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %6925 = torch.prim.ListConstruct %6924, %6922 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_8491 = torch.constant.int -1 + %6926 = torch.aten.cat %6925, %int-1_8491 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6926, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %6927 = torch.aten.mul.Tensor %6926, %6920 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6927, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_8492 = torch.constant.int 1 + %6928 = torch.aten.add.Tensor %6921, %6927, %int1_8492 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6928, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_8493 = torch.constant.int 32 + %6929 = torch.aten.mul.Scalar %arg2, %int32_8493 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6929, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int24 = torch.constant.int 24 + %int1_8494 = torch.constant.int 1 + %6930 = torch.aten.add.Scalar %6929, %int24, %int1_8494 : !torch.vtensor<[4,?],si64>, !torch.int, 
!torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6930, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_8495 = torch.constant.int 2 + %6931 = torch.aten.mul.Scalar %6930, %int2_8495 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6931, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_8496 = torch.constant.int 0 + %int1_8497 = torch.constant.int 1 + %6932 = torch.aten.add.Scalar %6931, %int0_8496, %int1_8497 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6932, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %6933 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %6934 = torch.aten.view %6932, %6933 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %6934, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_8498 = torch.constant.int 4 + %int32_8499 = torch.constant.int 32 + %int8_8500 = torch.constant.int 8 + %int128_8501 = torch.constant.int 128 + %6935 = torch.prim.ListConstruct %int4_8498, %296, %int32_8499, %int8_8500, %int128_8501 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6936 = torch.aten.view %6928, %6935 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6936, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_8502 = torch.constant.int 32 + %int8_8503 = torch.constant.int 8 + %int128_8504 = torch.constant.int 128 + %6937 = torch.prim.ListConstruct %504, %int32_8502, %int8_8503, %int128_8504 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6938 = torch.aten.view %6936, %6937 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %6938, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_8505 = torch.constant.int 1 + %int2_8506 = torch.constant.int 2 + %6939 = torch.aten.transpose.int %6938, %int1_8505, %int2_8506 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6939, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_8507 = torch.constant.int 5 + %6940 = torch.prims.convert_element_type %6939, %int5_8507 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6940, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_8508 = torch.constant.int 32 + %int2_8509 = torch.constant.int 2 + %int8_8510 = torch.constant.int 8 + %int32_8511 = torch.constant.int 32 + %int128_8512 = torch.constant.int 128 + %6941 = torch.prim.ListConstruct %297, %int32_8508, %int2_8509, %int8_8510, %int32_8511, %int128_8512 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6942 = torch.aten.view %6704, %6941 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6942, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_8513 = torch.constant.int 8 + %int32_8514 = torch.constant.int 32 + %int128_8515 = torch.constant.int 128 + %6943 = torch.prim.ListConstruct %497, %int8_8513, %int32_8514, 
%int128_8515 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6944 = torch.aten.view %6942, %6943 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6944, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %6945 = torch.prim.ListConstruct %6934 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_8516 = torch.constant.bool false + %6946 = torch.aten.index_put %6944, %6945, %6940, %false_8516 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6946, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_8517 = torch.constant.int 32 + %int2_8518 = torch.constant.int 2 + %int8_8519 = torch.constant.int 8 + %int32_8520 = torch.constant.int 32 + %int128_8521 = torch.constant.int 128 + %6947 = torch.prim.ListConstruct %297, %int32_8517, %int2_8518, %int8_8519, %int32_8520, %int128_8521 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6948 = torch.aten.view %6946, %6947 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6948, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_8522 = torch.constant.int 2097152 + %6949 = torch.prim.ListConstruct %297, %int2097152_8522 : (!torch.int, !torch.int) -> !torch.list + %6950 = torch.aten.view %6948, %6949 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6950, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_8523 = torch.constant.int 32 + %int2_8524 = torch.constant.int 2 + %int8_8525 = torch.constant.int 8 + %int32_8526 = torch.constant.int 32 + %int128_8527 = torch.constant.int 128 + %6951 = torch.prim.ListConstruct %297, %int32_8523, %int2_8524, %int8_8525, %int32_8526, %int128_8527 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6952 = torch.aten.view %6950, %6951 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6952, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_8528 = torch.constant.int 8 + %int32_8529 = torch.constant.int 32 + %int128_8530 = torch.constant.int 128 + %6953 = torch.prim.ListConstruct %497, %int8_8528, %int32_8529, %int128_8530 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6954 = torch.aten.view %6952, %6953 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6954, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_8531 = torch.constant.int 32 + %6955 = torch.aten.mul.Scalar %arg2, %int32_8531 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6955, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int24_8532 = torch.constant.int 24 + %int1_8533 = torch.constant.int 1 + %6956 = torch.aten.add.Scalar %6955, %int24_8532, %int1_8533 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6956, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_8534 = 
torch.constant.int 2 + %6957 = torch.aten.mul.Scalar %6956, %int2_8534 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6957, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_8535 = torch.constant.int 1 + %int1_8536 = torch.constant.int 1 + %6958 = torch.aten.add.Scalar %6957, %int1_8535, %int1_8536 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %6958, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %6959 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %6960 = torch.aten.view %6958, %6959 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %6960, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_8537 = torch.constant.int 4 + %int32_8538 = torch.constant.int 32 + %int8_8539 = torch.constant.int 8 + %int128_8540 = torch.constant.int 128 + %6961 = torch.prim.ListConstruct %int4_8537, %296, %int32_8538, %int8_8539, %int128_8540 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6962 = torch.aten.view %6802, %6961 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6962, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_8541 = torch.constant.int 32 + %int8_8542 = torch.constant.int 8 + %int128_8543 = torch.constant.int 128 + %6963 = torch.prim.ListConstruct %504, %int32_8541, %int8_8542, %int128_8543 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6964 = torch.aten.view %6962, %6963 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %6964, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_8544 = torch.constant.int 1 + %int2_8545 = torch.constant.int 2 + %6965 = torch.aten.transpose.int %6964, %int1_8544, %int2_8545 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6965, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_8546 = torch.constant.int 5 + %6966 = torch.prims.convert_element_type %6965, %int5_8546 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6966, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %6967 = torch.prim.ListConstruct %6960 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_8547 = torch.constant.bool false + %6968 = torch.aten.index_put %6954, %6967, %6966, %false_8547 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %6968, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_8548 = torch.constant.int 32 + %int2_8549 = torch.constant.int 2 + %int8_8550 = torch.constant.int 8 + %int32_8551 = torch.constant.int 32 + %int128_8552 = torch.constant.int 128 + %6969 = torch.prim.ListConstruct %297, %int32_8548, %int2_8549, %int8_8550, %int32_8551, %int128_8552 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6970 = torch.aten.view %6968, %6969 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape 
%6970, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_8553 = torch.constant.int 2097152 + %6971 = torch.prim.ListConstruct %297, %int2097152_8553 : (!torch.int, !torch.int) -> !torch.list + %6972 = torch.aten.view %6970, %6971 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6972, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_8554 = torch.constant.int -2 + %6973 = torch.aten.unsqueeze %6928, %int-2_8554 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %6973, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_8555 = torch.constant.int 4 + %int8_8556 = torch.constant.int 8 + %int4_8557 = torch.constant.int 4 + %int128_8558 = torch.constant.int 128 + %6974 = torch.prim.ListConstruct %int4_8555, %298, %int8_8556, %int4_8557, %int128_8558 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_8559 = torch.constant.bool false + %6975 = torch.aten.expand %6973, %6974, %false_8559 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6975, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_8560 = torch.constant.int 0 + %6976 = torch.aten.clone %6975, %int0_8560 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6976, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_8561 = torch.constant.int 4 + %int32_8562 = torch.constant.int 32 + %int128_8563 = torch.constant.int 128 + %6977 = torch.prim.ListConstruct %int4_8561, %298, %int32_8562, %int128_8563 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6978 = torch.aten._unsafe_view %6976, %6977 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6978, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_8564 = torch.constant.int -2 + %6979 = torch.aten.unsqueeze %6802, %int-2_8564 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %6979, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_8565 = torch.constant.int 4 + %int8_8566 = torch.constant.int 8 + %int4_8567 = torch.constant.int 4 + %int128_8568 = torch.constant.int 128 + %6980 = torch.prim.ListConstruct %int4_8565, %298, %int8_8566, %int4_8567, %int128_8568 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_8569 = torch.constant.bool false + %6981 = torch.aten.expand %6979, %6980, %false_8569 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6981, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_8570 = torch.constant.int 0 + %6982 = torch.aten.clone %6981, %int0_8570 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6982, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_8571 = torch.constant.int 4 + %int32_8572 = torch.constant.int 32 + %int128_8573 = 
torch.constant.int 128 + %6983 = torch.prim.ListConstruct %int4_8571, %298, %int32_8572, %int128_8573 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6984 = torch.aten._unsafe_view %6982, %6983 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6984, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_8574 = torch.constant.int 1 + %int2_8575 = torch.constant.int 2 + %6985 = torch.aten.transpose.int %6865, %int1_8574, %int2_8575 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6985, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_8576 = torch.constant.int 1 + %int2_8577 = torch.constant.int 2 + %6986 = torch.aten.transpose.int %6978, %int1_8576, %int2_8577 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6986, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_8578 = torch.constant.int 1 + %int2_8579 = torch.constant.int 2 + %6987 = torch.aten.transpose.int %6984, %int1_8578, %int2_8579 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6987, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_8580 = torch.constant.float 0.000000e+00 + %false_8581 = torch.constant.bool false + %none_8582 = torch.constant.none + %6988:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6985, %6986, %6987, %float0.000000e00_8580, %false_8581, %327, %none_8582) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %6988#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_8583 = torch.constant.int 1 + %int2_8584 = torch.constant.int 2 + %6989 = torch.aten.transpose.int %6988#0, %int1_8583, %int2_8584 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6989, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_8585 = torch.constant.int 4 + %int4096_8586 = torch.constant.int 4096 + %6990 = torch.prim.ListConstruct %int4_8585, %298, %int4096_8586 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6991 = torch.aten.view %6989, %6990 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6991, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_8587 = torch.constant.int -2 + %int-1_8588 = torch.constant.int -1 + %6992 = torch.aten.transpose.int %222, %int-2_8587, %int-1_8588 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_8589 = torch.constant.int 5 + %6993 = torch.prims.convert_element_type %6992, %int5_8589 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_8590 = torch.constant.int 4096 + %6994 = torch.prim.ListConstruct %342, %int4096_8590 : (!torch.int, !torch.int) -> !torch.list + %6995 = torch.aten.view %6991, %6994 : 
!torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6995, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %6996 = torch.aten.mm %6995, %6993 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %6996, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_8591 = torch.constant.int 4 + %int4096_8592 = torch.constant.int 4096 + %6997 = torch.prim.ListConstruct %int4_8591, %298, %int4096_8592 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6998 = torch.aten.view %6996, %6997 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6998, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_8593 = torch.constant.int 1 + %6999 = torch.aten.add.Tensor %6765, %6998, %int1_8593 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %6999, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_8594 = torch.constant.int 6 + %7000 = torch.prims.convert_element_type %6999, %int6_8594 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7000, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_8595 = torch.constant.int 2 + %7001 = torch.aten.pow.Tensor_Scalar %7000, %int2_8595 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7001, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_8596 = torch.constant.int -1 + %7002 = torch.prim.ListConstruct %int-1_8596 : (!torch.int) -> !torch.list + %true_8597 = torch.constant.bool true + %none_8598 = torch.constant.none + %7003 = torch.aten.mean.dim %7001, %7002, %true_8597, %none_8598 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7003, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_8599 = torch.constant.float 9.9999997473787516E-6 + %int1_8600 = torch.constant.int 1 + %7004 = torch.aten.add.Scalar %7003, %float9.999990e-06_8599, %int1_8600 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7004, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7005 = torch.aten.rsqrt %7004 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7005, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7006 = torch.aten.mul.Tensor %7000, %7005 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7006, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_8601 = torch.constant.int 5 + %7007 = torch.prims.convert_element_type %7006, %int5_8601 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7007, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %7008 = torch.aten.mul.Tensor %223, %7007 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7008, 
[%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_8602 = torch.constant.int 5 + %7009 = torch.prims.convert_element_type %7008, %int5_8602 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7009, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_8603 = torch.constant.int -2 + %int-1_8604 = torch.constant.int -1 + %7010 = torch.aten.transpose.int %224, %int-2_8603, %int-1_8604 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_8605 = torch.constant.int 5 + %7011 = torch.prims.convert_element_type %7010, %int5_8605 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_8606 = torch.constant.int 4096 + %7012 = torch.prim.ListConstruct %342, %int4096_8606 : (!torch.int, !torch.int) -> !torch.list + %7013 = torch.aten.view %7009, %7012 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7013, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7014 = torch.aten.mm %7013, %7011 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %7014, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_8607 = torch.constant.int 4 + %int14336_8608 = torch.constant.int 14336 + %7015 = torch.prim.ListConstruct %int4_8607, %298, %int14336_8608 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7016 = torch.aten.view %7014, %7015 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7016, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %7017 = torch.aten.silu %7016 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7017, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_8609 = torch.constant.int -2 + %int-1_8610 = torch.constant.int -1 + %7018 = torch.aten.transpose.int %225, %int-2_8609, %int-1_8610 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_8611 = torch.constant.int 5 + %7019 = torch.prims.convert_element_type %7018, %int5_8611 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_8612 = torch.constant.int 4096 + %7020 = torch.prim.ListConstruct %342, %int4096_8612 : (!torch.int, !torch.int) -> !torch.list + %7021 = torch.aten.view %7009, %7020 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7021, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7022 = torch.aten.mm %7021, %7019 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %7022, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_8613 = torch.constant.int 4 + %int14336_8614 = torch.constant.int 14336 + %7023 = torch.prim.ListConstruct %int4_8613, %298, %int14336_8614 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7024 = torch.aten.view %7022, %7023 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7024, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %7025 
= torch.aten.mul.Tensor %7017, %7024 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7025, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_8615 = torch.constant.int -2 + %int-1_8616 = torch.constant.int -1 + %7026 = torch.aten.transpose.int %226, %int-2_8615, %int-1_8616 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_8617 = torch.constant.int 5 + %7027 = torch.prims.convert_element_type %7026, %int5_8617 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_8618 = torch.constant.int 14336 + %7028 = torch.prim.ListConstruct %342, %int14336_8618 : (!torch.int, !torch.int) -> !torch.list + %7029 = torch.aten.view %7025, %7028 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %7029, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %7030 = torch.aten.mm %7029, %7027 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7030, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_8619 = torch.constant.int 4 + %int4096_8620 = torch.constant.int 4096 + %7031 = torch.prim.ListConstruct %int4_8619, %298, %int4096_8620 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7032 = torch.aten.view %7030, %7031 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7032, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_8621 = torch.constant.int 1 + %7033 = torch.aten.add.Tensor %6999, %7032, %int1_8621 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7033, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_8622 = torch.constant.int 6 + %7034 = torch.prims.convert_element_type %7033, %int6_8622 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7034, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_8623 = torch.constant.int 2 + %7035 = torch.aten.pow.Tensor_Scalar %7034, %int2_8623 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7035, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_8624 = torch.constant.int -1 + %7036 = torch.prim.ListConstruct %int-1_8624 : (!torch.int) -> !torch.list + %true_8625 = torch.constant.bool true + %none_8626 = torch.constant.none + %7037 = torch.aten.mean.dim %7035, %7036, %true_8625, %none_8626 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7037, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_8627 = torch.constant.float 9.9999997473787516E-6 + %int1_8628 = torch.constant.int 1 + %7038 = torch.aten.add.Scalar %7037, %float9.999990e-06_8627, %int1_8628 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7038, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7039 = torch.aten.rsqrt %7038 : !torch.vtensor<[4,?,1],f32> -> 
!torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7039, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7040 = torch.aten.mul.Tensor %7034, %7039 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7040, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_8629 = torch.constant.int 5 + %7041 = torch.prims.convert_element_type %7040, %int5_8629 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7041, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %7042 = torch.aten.mul.Tensor %227, %7041 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7042, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_8630 = torch.constant.int 5 + %7043 = torch.prims.convert_element_type %7042, %int5_8630 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7043, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_8631 = torch.constant.int -2 + %int-1_8632 = torch.constant.int -1 + %7044 = torch.aten.transpose.int %228, %int-2_8631, %int-1_8632 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_8633 = torch.constant.int 5 + %7045 = torch.prims.convert_element_type %7044, %int5_8633 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_8634 = torch.constant.int 4096 + %7046 = torch.prim.ListConstruct %342, %int4096_8634 : (!torch.int, !torch.int) -> !torch.list + %7047 = torch.aten.view %7043, %7046 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7047, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7048 = torch.aten.mm %7047, %7045 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7048, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_8635 = torch.constant.int 4 + %int4096_8636 = torch.constant.int 4096 + %7049 = torch.prim.ListConstruct %int4_8635, %298, %int4096_8636 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7050 = torch.aten.view %7048, %7049 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7050, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_8637 = torch.constant.int -2 + %int-1_8638 = torch.constant.int -1 + %7051 = torch.aten.transpose.int %229, %int-2_8637, %int-1_8638 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_8639 = torch.constant.int 5 + %7052 = torch.prims.convert_element_type %7051, %int5_8639 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_8640 = torch.constant.int 4096 + %7053 = torch.prim.ListConstruct %342, %int4096_8640 : (!torch.int, !torch.int) -> !torch.list + %7054 = torch.aten.view %7043, %7053 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7054, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7055 = torch.aten.mm %7054, %7052 : !torch.vtensor<[?,4096],f16>, 
!torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %7055, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_8641 = torch.constant.int 4 + %int1024_8642 = torch.constant.int 1024 + %7056 = torch.prim.ListConstruct %int4_8641, %298, %int1024_8642 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7057 = torch.aten.view %7055, %7056 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %7057, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_8643 = torch.constant.int -2 + %int-1_8644 = torch.constant.int -1 + %7058 = torch.aten.transpose.int %230, %int-2_8643, %int-1_8644 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_8645 = torch.constant.int 5 + %7059 = torch.prims.convert_element_type %7058, %int5_8645 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_8646 = torch.constant.int 4096 + %7060 = torch.prim.ListConstruct %342, %int4096_8646 : (!torch.int, !torch.int) -> !torch.list + %7061 = torch.aten.view %7043, %7060 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7061, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7062 = torch.aten.mm %7061, %7059 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %7062, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_8647 = torch.constant.int 4 + %int1024_8648 = torch.constant.int 1024 + %7063 = torch.prim.ListConstruct %int4_8647, %298, %int1024_8648 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7064 = torch.aten.view %7062, %7063 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %7064, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_8649 = torch.constant.int 4 + %int32_8650 = torch.constant.int 32 + %int128_8651 = torch.constant.int 128 + %7065 = torch.prim.ListConstruct %int4_8649, %298, %int32_8650, %int128_8651 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7066 = torch.aten.view %7050, %7065 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7066, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_8652 = torch.constant.int 4 + %int8_8653 = torch.constant.int 8 + %int128_8654 = torch.constant.int 128 + %7067 = torch.prim.ListConstruct %int4_8652, %298, %int8_8653, %int128_8654 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7068 = torch.aten.view %7057, %7067 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7068, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_8655 = torch.constant.int 4 + %int8_8656 = torch.constant.int 8 + %int128_8657 = torch.constant.int 128 + %7069 = torch.prim.ListConstruct %int4_8655, %298, %int8_8656, %int128_8657 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7070 = torch.aten.view %7064, %7069 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7070, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : 
!torch.vtensor<[4,?,8,128],f16> + %int131072_8658 = torch.constant.int 131072 + %none_8659 = torch.constant.none + %none_8660 = torch.constant.none + %cpu_8661 = torch.constant.device "cpu" + %false_8662 = torch.constant.bool false + %7071 = torch.aten.arange %int131072_8658, %none_8659, %none_8660, %cpu_8661, %false_8662 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_8663 = torch.constant.int 0 + %int128_8664 = torch.constant.int 128 + %int2_8665 = torch.constant.int 2 + %int4_8666 = torch.constant.int 4 + %none_8667 = torch.constant.none + %cpu_8668 = torch.constant.device "cpu" + %false_8669 = torch.constant.bool false + %7072 = torch.aten.arange.start_step %int0_8663, %int128_8664, %int2_8665, %int4_8666, %none_8667, %cpu_8668, %false_8669 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_8670 = torch.constant.int 6 + %7073 = torch.prims.convert_element_type %7072, %int6_8670 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_8671 = torch.constant.int 128 + %7074 = torch.aten.div.Scalar %7073, %int128_8671 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_8672 = torch.constant.float 5.000000e+05 + %7075 = torch.aten.pow.Scalar %float5.000000e05_8672, %7074 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7076 = torch.aten.reciprocal %7075 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_8673 = torch.constant.float 1.000000e+00 + %7077 = torch.aten.mul.Scalar %7076, %float1.000000e00_8673 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %7078 = torch.aten.reciprocal %7077 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_8674 = torch.constant.float 6.2831853071795862 + %7079 = torch.aten.mul.Scalar %7078, %float6.283190e00_8674 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_8675 = torch.constant.float 8.192000e+03 + %7080 = torch.aten.gt.Scalar %7079, %float8.192000e03_8675 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_8676 = torch.constant.int 8 + %7081 = torch.aten.div.Scalar %7077, %int8_8676 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7082 = torch.aten.where.self %7080, %7081, %7077 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7083 = torch.aten.reciprocal %7079 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_8677 = torch.constant.int 8192 + %7084 = torch.aten.mul.Scalar %7083, %int8192_8677 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_8678 = torch.constant.int 1 + %int1_8679 = torch.constant.int 1 + %7085 = torch.aten.sub.Scalar %7084, %int1_8678, %int1_8679 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_8680 = torch.constant.int 3 + %7086 = torch.aten.div.Scalar %7085, %int3_8680 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_8681 = torch.constant.int 1 + %int1_8682 = torch.constant.int 1 + %7087 = torch.aten.rsub.Scalar %7086, %int1_8681, %int1_8682 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %7088 = torch.aten.mul.Tensor %7087, %7082 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_8683 = torch.constant.int 8 + %7089 = 
torch.aten.div.Scalar %7088, %int8_8683 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7090 = torch.aten.mul.Tensor %7086, %7082 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_8684 = torch.constant.int 1 + %7091 = torch.aten.add.Tensor %7089, %7090, %int1_8684 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_8685 = torch.constant.float 2.048000e+03 + %7092 = torch.aten.lt.Scalar %7079, %float2.048000e03_8685 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7093 = torch.aten.bitwise_not %7092 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_8686 = torch.constant.float 8.192000e+03 + %7094 = torch.aten.gt.Scalar %7079, %float8.192000e03_8686 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7095 = torch.aten.bitwise_not %7094 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7096 = torch.aten.mul.Tensor %7093, %7095 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7097 = torch.aten.where.self %7096, %7091, %7082 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7098 = torch.prim.ListConstruct %7097, %7097 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_8687 = torch.constant.int -1 + %7099 = torch.aten.cat %7098, %int-1_8687 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_8688 = torch.constant.int 6 + %7100 = torch.prims.convert_element_type %7099, %int6_8688 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_8689 = torch.constant.int 1 + %7101 = torch.aten.unsqueeze %7071, %int1_8689 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_8690 = torch.constant.int 6 + %7102 = torch.prims.convert_element_type %7101, %int6_8690 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_8691 = torch.constant.int 0 + %7103 = torch.aten.unsqueeze %7100, %int0_8691 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_8692 = torch.constant.int 6 + %7104 = torch.prims.convert_element_type %7103, %int6_8692 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %7105 = torch.aten.mul.Tensor %7102, %7104 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %7106 = torch.aten.cos %7105 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_8693 = torch.constant.int 5 + %7107 = torch.prims.convert_element_type %7106, %int5_8693 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %7108 = torch.aten.sin %7105 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_8694 = torch.constant.int 5 + %7109 = torch.prims.convert_element_type %7108, %int5_8694 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_8695 = torch.constant.int 0 + %int0_8696 = torch.constant.int 0 + %int1_8697 = torch.constant.int 1 + %7110 = torch.aten.slice.Tensor %7107, %int0_8695, %int0_8696, %298, %int1_8697 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7110, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_8698 = torch.constant.int 1 + %int0_8699 = torch.constant.int 0 + 
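// NOTE (annotation, not emitted by the exporter): %7071-%7109 above materialize the
// Llama-3 scaled RoPE tables for this block: positions 0..131071, 64 inverse
// frequencies from theta = 5.0e5 over head_dim = 128, wavelengths above 8192 divided
// by the scaling factor 8, and the 2048..8192 band blended through the
// (8192/wavelength - 1)/3 smoothing term before cos/sin are taken in f32 and
// truncated to f16. The slices that follow cut both tables down to the current
// sequence length (s0 * 32 tokens).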
%int9223372036854775807_8700 = torch.constant.int 9223372036854775807 + %int1_8701 = torch.constant.int 1 + %7111 = torch.aten.slice.Tensor %7110, %int1_8698, %int0_8699, %int9223372036854775807_8700, %int1_8701 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7111, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_8702 = torch.constant.int 0 + %int0_8703 = torch.constant.int 0 + %int1_8704 = torch.constant.int 1 + %7112 = torch.aten.slice.Tensor %7109, %int0_8702, %int0_8703, %298, %int1_8704 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7112, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_8705 = torch.constant.int 1 + %int0_8706 = torch.constant.int 0 + %int9223372036854775807_8707 = torch.constant.int 9223372036854775807 + %int1_8708 = torch.constant.int 1 + %7113 = torch.aten.slice.Tensor %7112, %int1_8705, %int0_8706, %int9223372036854775807_8707, %int1_8708 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7113, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_8709 = torch.constant.int 0 + %7114 = torch.aten.unsqueeze %7111, %int0_8709 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7114, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_8710 = torch.constant.int 1 + %int0_8711 = torch.constant.int 0 + %int9223372036854775807_8712 = torch.constant.int 9223372036854775807 + %int1_8713 = torch.constant.int 1 + %7115 = torch.aten.slice.Tensor %7114, %int1_8710, %int0_8711, %int9223372036854775807_8712, %int1_8713 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7115, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_8714 = torch.constant.int 2 + %7116 = torch.aten.unsqueeze %7115, %int2_8714 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7116, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_8715 = torch.constant.int 3 + %int0_8716 = torch.constant.int 0 + %int9223372036854775807_8717 = torch.constant.int 9223372036854775807 + %int1_8718 = torch.constant.int 1 + %7117 = torch.aten.slice.Tensor %7116, %int3_8715, %int0_8716, %int9223372036854775807_8717, %int1_8718 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7117, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_8719 = torch.constant.int 4 + %int1_8720 = torch.constant.int 1 + %int1_8721 = torch.constant.int 1 + %int1_8722 = torch.constant.int 1 + %7118 = torch.prim.ListConstruct %int4_8719, %int1_8720, %int1_8721, %int1_8722 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7119 = torch.aten.repeat %7117, %7118 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7119, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_8723 = torch.constant.int 0 + %7120 = torch.aten.unsqueeze %7113, 
%int0_8723 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7120, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_8724 = torch.constant.int 1 + %int0_8725 = torch.constant.int 0 + %int9223372036854775807_8726 = torch.constant.int 9223372036854775807 + %int1_8727 = torch.constant.int 1 + %7121 = torch.aten.slice.Tensor %7120, %int1_8724, %int0_8725, %int9223372036854775807_8726, %int1_8727 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7121, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_8728 = torch.constant.int 2 + %7122 = torch.aten.unsqueeze %7121, %int2_8728 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7122, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_8729 = torch.constant.int 3 + %int0_8730 = torch.constant.int 0 + %int9223372036854775807_8731 = torch.constant.int 9223372036854775807 + %int1_8732 = torch.constant.int 1 + %7123 = torch.aten.slice.Tensor %7122, %int3_8729, %int0_8730, %int9223372036854775807_8731, %int1_8732 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7123, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_8733 = torch.constant.int 4 + %int1_8734 = torch.constant.int 1 + %int1_8735 = torch.constant.int 1 + %int1_8736 = torch.constant.int 1 + %7124 = torch.prim.ListConstruct %int4_8733, %int1_8734, %int1_8735, %int1_8736 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7125 = torch.aten.repeat %7123, %7124 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7125, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %7126 = torch.aten.mul.Tensor %7066, %7119 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7126, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_8737 = torch.constant.int 3 + %int0_8738 = torch.constant.int 0 + %int64_8739 = torch.constant.int 64 + %int1_8740 = torch.constant.int 1 + %7127 = torch.aten.slice.Tensor %7066, %int3_8737, %int0_8738, %int64_8739, %int1_8740 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %7127, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_8741 = torch.constant.int 3 + %int64_8742 = torch.constant.int 64 + %int9223372036854775807_8743 = torch.constant.int 9223372036854775807 + %int1_8744 = torch.constant.int 1 + %7128 = torch.aten.slice.Tensor %7066, %int3_8741, %int64_8742, %int9223372036854775807_8743, %int1_8744 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %7128, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %7129 = torch.aten.neg %7128 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %7129, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + 
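// NOTE (annotation): rotate-half RoPE application to the 32 query heads: %7126 is
// q * cos, and the cat of (-q[..., 64:], q[..., :64]) built next is multiplied by
// sin (%7132) and added in %7133, rotating each 128-wide head dimension.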
%7130 = torch.prim.ListConstruct %7129, %7127 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_8745 = torch.constant.int -1 + %7131 = torch.aten.cat %7130, %int-1_8745 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7131, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %7132 = torch.aten.mul.Tensor %7131, %7125 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7132, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_8746 = torch.constant.int 1 + %7133 = torch.aten.add.Tensor %7126, %7132, %int1_8746 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7133, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_8747 = torch.constant.int 131072 + %none_8748 = torch.constant.none + %none_8749 = torch.constant.none + %cpu_8750 = torch.constant.device "cpu" + %false_8751 = torch.constant.bool false + %7134 = torch.aten.arange %int131072_8747, %none_8748, %none_8749, %cpu_8750, %false_8751 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_8752 = torch.constant.int 0 + %int128_8753 = torch.constant.int 128 + %int2_8754 = torch.constant.int 2 + %int4_8755 = torch.constant.int 4 + %none_8756 = torch.constant.none + %cpu_8757 = torch.constant.device "cpu" + %false_8758 = torch.constant.bool false + %7135 = torch.aten.arange.start_step %int0_8752, %int128_8753, %int2_8754, %int4_8755, %none_8756, %cpu_8757, %false_8758 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_8759 = torch.constant.int 6 + %7136 = torch.prims.convert_element_type %7135, %int6_8759 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_8760 = torch.constant.int 128 + %7137 = torch.aten.div.Scalar %7136, %int128_8760 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_8761 = torch.constant.float 5.000000e+05 + %7138 = torch.aten.pow.Scalar %float5.000000e05_8761, %7137 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7139 = torch.aten.reciprocal %7138 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_8762 = torch.constant.float 1.000000e+00 + %7140 = torch.aten.mul.Scalar %7139, %float1.000000e00_8762 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %7141 = torch.aten.reciprocal %7140 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_8763 = torch.constant.float 6.2831853071795862 + %7142 = torch.aten.mul.Scalar %7141, %float6.283190e00_8763 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_8764 = torch.constant.float 8.192000e+03 + %7143 = torch.aten.gt.Scalar %7142, %float8.192000e03_8764 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_8765 = torch.constant.int 8 + %7144 = torch.aten.div.Scalar %7140, %int8_8765 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7145 = torch.aten.where.self %7143, %7144, %7140 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7146 = torch.aten.reciprocal %7142 : 
!torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_8766 = torch.constant.int 8192 + %7147 = torch.aten.mul.Scalar %7146, %int8192_8766 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_8767 = torch.constant.int 1 + %int1_8768 = torch.constant.int 1 + %7148 = torch.aten.sub.Scalar %7147, %int1_8767, %int1_8768 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_8769 = torch.constant.int 3 + %7149 = torch.aten.div.Scalar %7148, %int3_8769 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_8770 = torch.constant.int 1 + %int1_8771 = torch.constant.int 1 + %7150 = torch.aten.rsub.Scalar %7149, %int1_8770, %int1_8771 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %7151 = torch.aten.mul.Tensor %7150, %7145 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_8772 = torch.constant.int 8 + %7152 = torch.aten.div.Scalar %7151, %int8_8772 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7153 = torch.aten.mul.Tensor %7149, %7145 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_8773 = torch.constant.int 1 + %7154 = torch.aten.add.Tensor %7152, %7153, %int1_8773 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_8774 = torch.constant.float 2.048000e+03 + %7155 = torch.aten.lt.Scalar %7142, %float2.048000e03_8774 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7156 = torch.aten.bitwise_not %7155 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_8775 = torch.constant.float 8.192000e+03 + %7157 = torch.aten.gt.Scalar %7142, %float8.192000e03_8775 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7158 = torch.aten.bitwise_not %7157 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7159 = torch.aten.mul.Tensor %7156, %7158 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7160 = torch.aten.where.self %7159, %7154, %7145 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7161 = torch.prim.ListConstruct %7160, %7160 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_8776 = torch.constant.int -1 + %7162 = torch.aten.cat %7161, %int-1_8776 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_8777 = torch.constant.int 6 + %7163 = torch.prims.convert_element_type %7162, %int6_8777 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_8778 = torch.constant.int 1 + %7164 = torch.aten.unsqueeze %7134, %int1_8778 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_8779 = torch.constant.int 6 + %7165 = torch.prims.convert_element_type %7164, %int6_8779 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_8780 = torch.constant.int 0 + %7166 = torch.aten.unsqueeze %7163, %int0_8780 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_8781 = torch.constant.int 6 + %7167 = torch.prims.convert_element_type %7166, %int6_8781 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %7168 = torch.aten.mul.Tensor %7165, %7167 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %7169 = torch.aten.cos %7168 : !torch.vtensor<[131072,128],f32> -> 
!torch.vtensor<[131072,128],f32> + %int5_8782 = torch.constant.int 5 + %7170 = torch.prims.convert_element_type %7169, %int5_8782 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %7171 = torch.aten.sin %7168 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_8783 = torch.constant.int 5 + %7172 = torch.prims.convert_element_type %7171, %int5_8783 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_8784 = torch.constant.int 0 + %int0_8785 = torch.constant.int 0 + %int1_8786 = torch.constant.int 1 + %7173 = torch.aten.slice.Tensor %7170, %int0_8784, %int0_8785, %298, %int1_8786 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7173, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_8787 = torch.constant.int 1 + %int0_8788 = torch.constant.int 0 + %int9223372036854775807_8789 = torch.constant.int 9223372036854775807 + %int1_8790 = torch.constant.int 1 + %7174 = torch.aten.slice.Tensor %7173, %int1_8787, %int0_8788, %int9223372036854775807_8789, %int1_8790 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7174, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_8791 = torch.constant.int 0 + %int0_8792 = torch.constant.int 0 + %int1_8793 = torch.constant.int 1 + %7175 = torch.aten.slice.Tensor %7172, %int0_8791, %int0_8792, %298, %int1_8793 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7175, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_8794 = torch.constant.int 1 + %int0_8795 = torch.constant.int 0 + %int9223372036854775807_8796 = torch.constant.int 9223372036854775807 + %int1_8797 = torch.constant.int 1 + %7176 = torch.aten.slice.Tensor %7175, %int1_8794, %int0_8795, %int9223372036854775807_8796, %int1_8797 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7176, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_8798 = torch.constant.int 0 + %7177 = torch.aten.unsqueeze %7174, %int0_8798 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7177, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_8799 = torch.constant.int 1 + %int0_8800 = torch.constant.int 0 + %int9223372036854775807_8801 = torch.constant.int 9223372036854775807 + %int1_8802 = torch.constant.int 1 + %7178 = torch.aten.slice.Tensor %7177, %int1_8799, %int0_8800, %int9223372036854775807_8801, %int1_8802 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7178, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_8803 = torch.constant.int 2 + %7179 = torch.aten.unsqueeze %7178, %int2_8803 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7179, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_8804 = torch.constant.int 3 + %int0_8805 = torch.constant.int 0 + %int9223372036854775807_8806 = torch.constant.int 9223372036854775807 + 
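// NOTE (annotation): %7134-%7176 repeat the same RoPE table construction verbatim
// for the key path; the non-decomposed export re-derives the tables rather than
// reusing %7107/%7109. The rotated 8-head keys are then scattered into the paged
// f16 cache below via index_put at flattened slot (arg2 * 32 + 25) * 2 (the K
// half), consistent with the [page, 32 blocks, 2 (k/v), 8 heads, 32 tokens, 128]
// cache view and a block index of 25 for this layer.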
%int1_8807 = torch.constant.int 1 + %7180 = torch.aten.slice.Tensor %7179, %int3_8804, %int0_8805, %int9223372036854775807_8806, %int1_8807 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7180, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_8808 = torch.constant.int 4 + %int1_8809 = torch.constant.int 1 + %int1_8810 = torch.constant.int 1 + %int1_8811 = torch.constant.int 1 + %7181 = torch.prim.ListConstruct %int4_8808, %int1_8809, %int1_8810, %int1_8811 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7182 = torch.aten.repeat %7180, %7181 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7182, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_8812 = torch.constant.int 0 + %7183 = torch.aten.unsqueeze %7176, %int0_8812 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7183, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_8813 = torch.constant.int 1 + %int0_8814 = torch.constant.int 0 + %int9223372036854775807_8815 = torch.constant.int 9223372036854775807 + %int1_8816 = torch.constant.int 1 + %7184 = torch.aten.slice.Tensor %7183, %int1_8813, %int0_8814, %int9223372036854775807_8815, %int1_8816 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7184, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_8817 = torch.constant.int 2 + %7185 = torch.aten.unsqueeze %7184, %int2_8817 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7185, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_8818 = torch.constant.int 3 + %int0_8819 = torch.constant.int 0 + %int9223372036854775807_8820 = torch.constant.int 9223372036854775807 + %int1_8821 = torch.constant.int 1 + %7186 = torch.aten.slice.Tensor %7185, %int3_8818, %int0_8819, %int9223372036854775807_8820, %int1_8821 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7186, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_8822 = torch.constant.int 4 + %int1_8823 = torch.constant.int 1 + %int1_8824 = torch.constant.int 1 + %int1_8825 = torch.constant.int 1 + %7187 = torch.prim.ListConstruct %int4_8822, %int1_8823, %int1_8824, %int1_8825 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7188 = torch.aten.repeat %7186, %7187 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7188, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %7189 = torch.aten.mul.Tensor %7068, %7182 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7189, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_8826 = torch.constant.int 3 + %int0_8827 = torch.constant.int 0 + %int64_8828 = torch.constant.int 64 + %int1_8829 = torch.constant.int 1 + %7190 = torch.aten.slice.Tensor %7068, %int3_8826, %int0_8827, %int64_8828, %int1_8829 : 
!torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %7190, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_8830 = torch.constant.int 3 + %int64_8831 = torch.constant.int 64 + %int9223372036854775807_8832 = torch.constant.int 9223372036854775807 + %int1_8833 = torch.constant.int 1 + %7191 = torch.aten.slice.Tensor %7068, %int3_8830, %int64_8831, %int9223372036854775807_8832, %int1_8833 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %7191, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %7192 = torch.aten.neg %7191 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %7192, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %7193 = torch.prim.ListConstruct %7192, %7190 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_8834 = torch.constant.int -1 + %7194 = torch.aten.cat %7193, %int-1_8834 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7194, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %7195 = torch.aten.mul.Tensor %7194, %7188 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7195, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_8835 = torch.constant.int 1 + %7196 = torch.aten.add.Tensor %7189, %7195, %int1_8835 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7196, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_8836 = torch.constant.int 32 + %7197 = torch.aten.mul.Scalar %arg2, %int32_8836 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7197, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int25 = torch.constant.int 25 + %int1_8837 = torch.constant.int 1 + %7198 = torch.aten.add.Scalar %7197, %int25, %int1_8837 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7198, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_8838 = torch.constant.int 2 + %7199 = torch.aten.mul.Scalar %7198, %int2_8838 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7199, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_8839 = torch.constant.int 0 + %int1_8840 = torch.constant.int 1 + %7200 = torch.aten.add.Scalar %7199, %int0_8839, %int1_8840 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7200, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %7201 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %7202 = torch.aten.view %7200, %7201 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %7202, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_8841 = torch.constant.int 4 + %int32_8842 = torch.constant.int 32 + %int8_8843 = torch.constant.int 8 + %int128_8844 = torch.constant.int 128 + %7203 = 
torch.prim.ListConstruct %int4_8841, %296, %int32_8842, %int8_8843, %int128_8844 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7204 = torch.aten.view %7196, %7203 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7204, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_8845 = torch.constant.int 32 + %int8_8846 = torch.constant.int 8 + %int128_8847 = torch.constant.int 128 + %7205 = torch.prim.ListConstruct %504, %int32_8845, %int8_8846, %int128_8847 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7206 = torch.aten.view %7204, %7205 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %7206, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_8848 = torch.constant.int 1 + %int2_8849 = torch.constant.int 2 + %7207 = torch.aten.transpose.int %7206, %int1_8848, %int2_8849 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7207, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_8850 = torch.constant.int 5 + %7208 = torch.prims.convert_element_type %7207, %int5_8850 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7208, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_8851 = torch.constant.int 32 + %int2_8852 = torch.constant.int 2 + %int8_8853 = torch.constant.int 8 + %int32_8854 = torch.constant.int 32 + %int128_8855 = torch.constant.int 128 + %7209 = torch.prim.ListConstruct %297, %int32_8851, %int2_8852, %int8_8853, %int32_8854, %int128_8855 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7210 = torch.aten.view %6972, %7209 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7210, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_8856 = torch.constant.int 8 + %int32_8857 = torch.constant.int 32 + %int128_8858 = torch.constant.int 128 + %7211 = torch.prim.ListConstruct %497, %int8_8856, %int32_8857, %int128_8858 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7212 = torch.aten.view %7210, %7211 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7212, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %7213 = torch.prim.ListConstruct %7202 : (!torch.vtensor<[?],si64>) -> !torch.list + %false_8859 = torch.constant.bool false + %7214 = torch.aten.index_put %7212, %7213, %7208, %false_8859 : !torch.vtensor<[?,8,32,128],f16>, !torch.list, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7214, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_8860 = torch.constant.int 32 + %int2_8861 = torch.constant.int 2 + %int8_8862 = torch.constant.int 8 + %int32_8863 = torch.constant.int 32 + %int128_8864 = torch.constant.int 128 + %7215 = torch.prim.ListConstruct %297, %int32_8860, %int2_8861, %int8_8862, %int32_8863, %int128_8864 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7216 = 
torch.aten.view %7214, %7215 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7216, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_8865 = torch.constant.int 2097152 + %7217 = torch.prim.ListConstruct %297, %int2097152_8865 : (!torch.int, !torch.int) -> !torch.list + %7218 = torch.aten.view %7216, %7217 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %7218, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_8866 = torch.constant.int 32 + %int2_8867 = torch.constant.int 2 + %int8_8868 = torch.constant.int 8 + %int32_8869 = torch.constant.int 32 + %int128_8870 = torch.constant.int 128 + %7219 = torch.prim.ListConstruct %297, %int32_8866, %int2_8867, %int8_8868, %int32_8869, %int128_8870 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7220 = torch.aten.view %7218, %7219 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7220, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_8871 = torch.constant.int 8 + %int32_8872 = torch.constant.int 32 + %int128_8873 = torch.constant.int 128 + %7221 = torch.prim.ListConstruct %497, %int8_8871, %int32_8872, %int128_8873 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7222 = torch.aten.view %7220, %7221 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7222, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_8874 = torch.constant.int 32 + %7223 = torch.aten.mul.Scalar %arg2, %int32_8874 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7223, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int25_8875 = torch.constant.int 25 + %int1_8876 = torch.constant.int 1 + %7224 = torch.aten.add.Scalar %7223, %int25_8875, %int1_8876 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7224, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_8877 = torch.constant.int 2 + %7225 = torch.aten.mul.Scalar %7224, %int2_8877 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7225, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_8878 = torch.constant.int 1 + %int1_8879 = torch.constant.int 1 + %7226 = torch.aten.add.Scalar %7225, %int1_8878, %int1_8879 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7226, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %7227 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %7228 = torch.aten.view %7226, %7227 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %7228, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_8880 = torch.constant.int 4 + %int32_8881 = torch.constant.int 32 + %int8_8882 = torch.constant.int 8 + %int128_8883 = torch.constant.int 128 + %7229 = torch.prim.ListConstruct %int4_8880, %296, %int32_8881, %int8_8882, %int128_8883 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7230 = 
torch.aten.view %7070, %7229 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7230, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_8884 = torch.constant.int 32 + %int8_8885 = torch.constant.int 8 + %int128_8886 = torch.constant.int 128 + %7231 = torch.prim.ListConstruct %504, %int32_8884, %int8_8885, %int128_8886 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7232 = torch.aten.view %7230, %7231 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %7232, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_8887 = torch.constant.int 1 + %int2_8888 = torch.constant.int 2 + %7233 = torch.aten.transpose.int %7232, %int1_8887, %int2_8888 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7233, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_8889 = torch.constant.int 5 + %7234 = torch.prims.convert_element_type %7233, %int5_8889 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7234, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %7235 = torch.prim.ListConstruct %7228 : (!torch.vtensor<[?],si64>) -> !torch.list + %false_8890 = torch.constant.bool false + %7236 = torch.aten.index_put %7222, %7235, %7234, %false_8890 : !torch.vtensor<[?,8,32,128],f16>, !torch.list, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7236, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_8891 = torch.constant.int 32 + %int2_8892 = torch.constant.int 2 + %int8_8893 = torch.constant.int 8 + %int32_8894 = torch.constant.int 32 + %int128_8895 = torch.constant.int 128 + %7237 = torch.prim.ListConstruct %297, %int32_8891, %int2_8892, %int8_8893, %int32_8894, %int128_8895 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7238 = torch.aten.view %7236, %7237 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7238, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_8896 = torch.constant.int 2097152 + %7239 = torch.prim.ListConstruct %297, %int2097152_8896 : (!torch.int, !torch.int) -> !torch.list + %7240 = torch.aten.view %7238, %7239 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %7240, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_8897 = torch.constant.int -2 + %7241 = torch.aten.unsqueeze %7196, %int-2_8897 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %7241, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_8898 = torch.constant.int 4 + %int8_8899 = torch.constant.int 8 + %int4_8900 = torch.constant.int 4 + %int128_8901 = torch.constant.int 128 + %7242 = torch.prim.ListConstruct %int4_8898, %298, %int8_8899, %int4_8900, %int128_8901 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_8902 = torch.constant.bool false + %7243 = torch.aten.expand 
%7241, %7242, %false_8902 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7243, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_8903 = torch.constant.int 0 + %7244 = torch.aten.clone %7243, %int0_8903 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7244, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_8904 = torch.constant.int 4 + %int32_8905 = torch.constant.int 32 + %int128_8906 = torch.constant.int 128 + %7245 = torch.prim.ListConstruct %int4_8904, %298, %int32_8905, %int128_8906 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7246 = torch.aten._unsafe_view %7244, %7245 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7246, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_8907 = torch.constant.int -2 + %7247 = torch.aten.unsqueeze %7070, %int-2_8907 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %7247, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_8908 = torch.constant.int 4 + %int8_8909 = torch.constant.int 8 + %int4_8910 = torch.constant.int 4 + %int128_8911 = torch.constant.int 128 + %7248 = torch.prim.ListConstruct %int4_8908, %298, %int8_8909, %int4_8910, %int128_8911 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_8912 = torch.constant.bool false + %7249 = torch.aten.expand %7247, %7248, %false_8912 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7249, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_8913 = torch.constant.int 0 + %7250 = torch.aten.clone %7249, %int0_8913 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7250, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_8914 = torch.constant.int 4 + %int32_8915 = torch.constant.int 32 + %int128_8916 = torch.constant.int 128 + %7251 = torch.prim.ListConstruct %int4_8914, %298, %int32_8915, %int128_8916 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7252 = torch.aten._unsafe_view %7250, %7251 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7252, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_8917 = torch.constant.int 1 + %int2_8918 = torch.constant.int 2 + %7253 = torch.aten.transpose.int %7133, %int1_8917, %int2_8918 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7253, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_8919 = torch.constant.int 1 + %int2_8920 = torch.constant.int 2 + %7254 = torch.aten.transpose.int %7246, %int1_8919, %int2_8920 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7254, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_8921 = torch.constant.int 1 + %int2_8922 = 
torch.constant.int 2 + %7255 = torch.aten.transpose.int %7252, %int1_8921, %int2_8922 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7255, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_8923 = torch.constant.float 0.000000e+00 + %false_8924 = torch.constant.bool false + %none_8925 = torch.constant.none + %7256:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%7253, %7254, %7255, %float0.000000e00_8923, %false_8924, %327, %none_8925) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %7256#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_8926 = torch.constant.int 1 + %int2_8927 = torch.constant.int 2 + %7257 = torch.aten.transpose.int %7256#0, %int1_8926, %int2_8927 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7257, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_8928 = torch.constant.int 4 + %int4096_8929 = torch.constant.int 4096 + %7258 = torch.prim.ListConstruct %int4_8928, %298, %int4096_8929 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7259 = torch.aten.view %7257, %7258 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7259, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_8930 = torch.constant.int -2 + %int-1_8931 = torch.constant.int -1 + %7260 = torch.aten.transpose.int %231, %int-2_8930, %int-1_8931 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_8932 = torch.constant.int 5 + %7261 = torch.prims.convert_element_type %7260, %int5_8932 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_8933 = torch.constant.int 4096 + %7262 = torch.prim.ListConstruct %342, %int4096_8933 : (!torch.int, !torch.int) -> !torch.list + %7263 = torch.aten.view %7259, %7262 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7263, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7264 = torch.aten.mm %7263, %7261 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7264, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_8934 = torch.constant.int 4 + %int4096_8935 = torch.constant.int 4096 + %7265 = torch.prim.ListConstruct %int4_8934, %298, %int4096_8935 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7266 = torch.aten.view %7264, %7265 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7266, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_8936 = torch.constant.int 1 + %7267 = torch.aten.add.Tensor %7033, %7266, %int1_8936 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7267, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + 
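// Attention output projection and residual add are complete; the ops that follow apply RMSNorm in f32 (eps ~ 1e-5) and the SwiGLU FFN (silu(gate) * up, then the down projection) before a second residual add. +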
%int6_8937 = torch.constant.int 6 + %7268 = torch.prims.convert_element_type %7267, %int6_8937 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7268, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_8938 = torch.constant.int 2 + %7269 = torch.aten.pow.Tensor_Scalar %7268, %int2_8938 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7269, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_8939 = torch.constant.int -1 + %7270 = torch.prim.ListConstruct %int-1_8939 : (!torch.int) -> !torch.list + %true_8940 = torch.constant.bool true + %none_8941 = torch.constant.none + %7271 = torch.aten.mean.dim %7269, %7270, %true_8940, %none_8941 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7271, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_8942 = torch.constant.float 9.9999997473787516E-6 + %int1_8943 = torch.constant.int 1 + %7272 = torch.aten.add.Scalar %7271, %float9.999990e-06_8942, %int1_8943 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7272, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7273 = torch.aten.rsqrt %7272 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7273, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7274 = torch.aten.mul.Tensor %7268, %7273 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7274, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_8944 = torch.constant.int 5 + %7275 = torch.prims.convert_element_type %7274, %int5_8944 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7275, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %7276 = torch.aten.mul.Tensor %232, %7275 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7276, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_8945 = torch.constant.int 5 + %7277 = torch.prims.convert_element_type %7276, %int5_8945 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7277, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_8946 = torch.constant.int -2 + %int-1_8947 = torch.constant.int -1 + %7278 = torch.aten.transpose.int %233, %int-2_8946, %int-1_8947 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_8948 = torch.constant.int 5 + %7279 = torch.prims.convert_element_type %7278, %int5_8948 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_8949 = torch.constant.int 4096 + %7280 = torch.prim.ListConstruct %342, %int4096_8949 : (!torch.int, !torch.int) -> !torch.list + %7281 = torch.aten.view %7277, %7280 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7281, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7282 = torch.aten.mm %7281, 
%7279 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %7282, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_8950 = torch.constant.int 4 + %int14336_8951 = torch.constant.int 14336 + %7283 = torch.prim.ListConstruct %int4_8950, %298, %int14336_8951 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7284 = torch.aten.view %7282, %7283 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7284, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %7285 = torch.aten.silu %7284 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7285, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_8952 = torch.constant.int -2 + %int-1_8953 = torch.constant.int -1 + %7286 = torch.aten.transpose.int %234, %int-2_8952, %int-1_8953 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_8954 = torch.constant.int 5 + %7287 = torch.prims.convert_element_type %7286, %int5_8954 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_8955 = torch.constant.int 4096 + %7288 = torch.prim.ListConstruct %342, %int4096_8955 : (!torch.int, !torch.int) -> !torch.list + %7289 = torch.aten.view %7277, %7288 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7289, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7290 = torch.aten.mm %7289, %7287 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %7290, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_8956 = torch.constant.int 4 + %int14336_8957 = torch.constant.int 14336 + %7291 = torch.prim.ListConstruct %int4_8956, %298, %int14336_8957 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7292 = torch.aten.view %7290, %7291 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7292, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %7293 = torch.aten.mul.Tensor %7285, %7292 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7293, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_8958 = torch.constant.int -2 + %int-1_8959 = torch.constant.int -1 + %7294 = torch.aten.transpose.int %235, %int-2_8958, %int-1_8959 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_8960 = torch.constant.int 5 + %7295 = torch.prims.convert_element_type %7294, %int5_8960 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_8961 = torch.constant.int 14336 + %7296 = torch.prim.ListConstruct %342, %int14336_8961 : (!torch.int, !torch.int) -> !torch.list + %7297 = torch.aten.view %7293, %7296 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %7297, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %7298 = torch.aten.mm %7297, %7295 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + 
torch.bind_symbolic_shape %7298, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_8962 = torch.constant.int 4 + %int4096_8963 = torch.constant.int 4096 + %7299 = torch.prim.ListConstruct %int4_8962, %298, %int4096_8963 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7300 = torch.aten.view %7298, %7299 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7300, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_8964 = torch.constant.int 1 + %7301 = torch.aten.add.Tensor %7267, %7300, %int1_8964 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7301, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_8965 = torch.constant.int 6 + %7302 = torch.prims.convert_element_type %7301, %int6_8965 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7302, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_8966 = torch.constant.int 2 + %7303 = torch.aten.pow.Tensor_Scalar %7302, %int2_8966 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7303, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_8967 = torch.constant.int -1 + %7304 = torch.prim.ListConstruct %int-1_8967 : (!torch.int) -> !torch.list + %true_8968 = torch.constant.bool true + %none_8969 = torch.constant.none + %7305 = torch.aten.mean.dim %7303, %7304, %true_8968, %none_8969 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7305, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_8970 = torch.constant.float 9.9999997473787516E-6 + %int1_8971 = torch.constant.int 1 + %7306 = torch.aten.add.Scalar %7305, %float9.999990e-06_8970, %int1_8971 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7306, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7307 = torch.aten.rsqrt %7306 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7307, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7308 = torch.aten.mul.Tensor %7302, %7307 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7308, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_8972 = torch.constant.int 5 + %7309 = torch.prims.convert_element_type %7308, %int5_8972 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7309, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %7310 = torch.aten.mul.Tensor %236, %7309 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7310, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_8973 = torch.constant.int 5 + %7311 = torch.prims.convert_element_type %7310, %int5_8973 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7311, [%294], affine_map<()[s0] -> (4, 
s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_8974 = torch.constant.int -2 + %int-1_8975 = torch.constant.int -1 + %7312 = torch.aten.transpose.int %237, %int-2_8974, %int-1_8975 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_8976 = torch.constant.int 5 + %7313 = torch.prims.convert_element_type %7312, %int5_8976 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_8977 = torch.constant.int 4096 + %7314 = torch.prim.ListConstruct %342, %int4096_8977 : (!torch.int, !torch.int) -> !torch.list + %7315 = torch.aten.view %7311, %7314 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7315, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7316 = torch.aten.mm %7315, %7313 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7316, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_8978 = torch.constant.int 4 + %int4096_8979 = torch.constant.int 4096 + %7317 = torch.prim.ListConstruct %int4_8978, %298, %int4096_8979 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7318 = torch.aten.view %7316, %7317 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7318, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_8980 = torch.constant.int -2 + %int-1_8981 = torch.constant.int -1 + %7319 = torch.aten.transpose.int %238, %int-2_8980, %int-1_8981 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_8982 = torch.constant.int 5 + %7320 = torch.prims.convert_element_type %7319, %int5_8982 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_8983 = torch.constant.int 4096 + %7321 = torch.prim.ListConstruct %342, %int4096_8983 : (!torch.int, !torch.int) -> !torch.list + %7322 = torch.aten.view %7311, %7321 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7322, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7323 = torch.aten.mm %7322, %7320 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %7323, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_8984 = torch.constant.int 4 + %int1024_8985 = torch.constant.int 1024 + %7324 = torch.prim.ListConstruct %int4_8984, %298, %int1024_8985 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7325 = torch.aten.view %7323, %7324 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %7325, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_8986 = torch.constant.int -2 + %int-1_8987 = torch.constant.int -1 + %7326 = torch.aten.transpose.int %239, %int-2_8986, %int-1_8987 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_8988 = torch.constant.int 5 + %7327 = torch.prims.convert_element_type %7326, %int5_8988 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_8989 = torch.constant.int 4096 + %7328 = torch.prim.ListConstruct %342, %int4096_8989 : (!torch.int, !torch.int) -> !torch.list + %7329 = torch.aten.view 
%7311, %7328 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7329, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7330 = torch.aten.mm %7329, %7327 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %7330, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_8990 = torch.constant.int 4 + %int1024_8991 = torch.constant.int 1024 + %7331 = torch.prim.ListConstruct %int4_8990, %298, %int1024_8991 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7332 = torch.aten.view %7330, %7331 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %7332, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_8992 = torch.constant.int 4 + %int32_8993 = torch.constant.int 32 + %int128_8994 = torch.constant.int 128 + %7333 = torch.prim.ListConstruct %int4_8992, %298, %int32_8993, %int128_8994 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7334 = torch.aten.view %7318, %7333 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7334, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_8995 = torch.constant.int 4 + %int8_8996 = torch.constant.int 8 + %int128_8997 = torch.constant.int 128 + %7335 = torch.prim.ListConstruct %int4_8995, %298, %int8_8996, %int128_8997 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7336 = torch.aten.view %7325, %7335 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7336, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_8998 = torch.constant.int 4 + %int8_8999 = torch.constant.int 8 + %int128_9000 = torch.constant.int 128 + %7337 = torch.prim.ListConstruct %int4_8998, %298, %int8_8999, %int128_9000 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7338 = torch.aten.view %7332, %7337 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7338, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_9001 = torch.constant.int 131072 + %none_9002 = torch.constant.none + %none_9003 = torch.constant.none + %cpu_9004 = torch.constant.device "cpu" + %false_9005 = torch.constant.bool false + %7339 = torch.aten.arange %int131072_9001, %none_9002, %none_9003, %cpu_9004, %false_9005 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_9006 = torch.constant.int 0 + %int128_9007 = torch.constant.int 128 + %int2_9008 = torch.constant.int 2 + %int4_9009 = torch.constant.int 4 + %none_9010 = torch.constant.none + %cpu_9011 = torch.constant.device "cpu" + %false_9012 = torch.constant.bool false + %7340 = torch.aten.arange.start_step %int0_9006, %int128_9007, %int2_9008, %int4_9009, %none_9010, %cpu_9011, %false_9012 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_9013 = torch.constant.int 6 + %7341 = torch.prims.convert_element_type %7340, %int6_9013 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_9014 = torch.constant.int 128 + %7342 = torch.aten.div.Scalar %7341, %int128_9014 : 
!torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_9015 = torch.constant.float 5.000000e+05 + %7343 = torch.aten.pow.Scalar %float5.000000e05_9015, %7342 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7344 = torch.aten.reciprocal %7343 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_9016 = torch.constant.float 1.000000e+00 + %7345 = torch.aten.mul.Scalar %7344, %float1.000000e00_9016 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %7346 = torch.aten.reciprocal %7345 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_9017 = torch.constant.float 6.2831853071795862 + %7347 = torch.aten.mul.Scalar %7346, %float6.283190e00_9017 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_9018 = torch.constant.float 8.192000e+03 + %7348 = torch.aten.gt.Scalar %7347, %float8.192000e03_9018 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_9019 = torch.constant.int 8 + %7349 = torch.aten.div.Scalar %7345, %int8_9019 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7350 = torch.aten.where.self %7348, %7349, %7345 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7351 = torch.aten.reciprocal %7347 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_9020 = torch.constant.int 8192 + %7352 = torch.aten.mul.Scalar %7351, %int8192_9020 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_9021 = torch.constant.int 1 + %int1_9022 = torch.constant.int 1 + %7353 = torch.aten.sub.Scalar %7352, %int1_9021, %int1_9022 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_9023 = torch.constant.int 3 + %7354 = torch.aten.div.Scalar %7353, %int3_9023 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_9024 = torch.constant.int 1 + %int1_9025 = torch.constant.int 1 + %7355 = torch.aten.rsub.Scalar %7354, %int1_9024, %int1_9025 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %7356 = torch.aten.mul.Tensor %7355, %7350 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_9026 = torch.constant.int 8 + %7357 = torch.aten.div.Scalar %7356, %int8_9026 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7358 = torch.aten.mul.Tensor %7354, %7350 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_9027 = torch.constant.int 1 + %7359 = torch.aten.add.Tensor %7357, %7358, %int1_9027 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_9028 = torch.constant.float 2.048000e+03 + %7360 = torch.aten.lt.Scalar %7347, %float2.048000e03_9028 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7361 = torch.aten.bitwise_not %7360 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_9029 = torch.constant.float 8.192000e+03 + %7362 = torch.aten.gt.Scalar %7347, %float8.192000e03_9029 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7363 = torch.aten.bitwise_not %7362 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7364 = torch.aten.mul.Tensor %7361, %7363 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7365 = torch.aten.where.self %7364, %7359, %7350 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, 
!torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7366 = torch.prim.ListConstruct %7365, %7365 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_9030 = torch.constant.int -1 + %7367 = torch.aten.cat %7366, %int-1_9030 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_9031 = torch.constant.int 6 + %7368 = torch.prims.convert_element_type %7367, %int6_9031 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_9032 = torch.constant.int 1 + %7369 = torch.aten.unsqueeze %7339, %int1_9032 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_9033 = torch.constant.int 6 + %7370 = torch.prims.convert_element_type %7369, %int6_9033 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_9034 = torch.constant.int 0 + %7371 = torch.aten.unsqueeze %7368, %int0_9034 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_9035 = torch.constant.int 6 + %7372 = torch.prims.convert_element_type %7371, %int6_9035 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %7373 = torch.aten.mul.Tensor %7370, %7372 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %7374 = torch.aten.cos %7373 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_9036 = torch.constant.int 5 + %7375 = torch.prims.convert_element_type %7374, %int5_9036 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %7376 = torch.aten.sin %7373 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_9037 = torch.constant.int 5 + %7377 = torch.prims.convert_element_type %7376, %int5_9037 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_9038 = torch.constant.int 0 + %int0_9039 = torch.constant.int 0 + %int1_9040 = torch.constant.int 1 + %7378 = torch.aten.slice.Tensor %7375, %int0_9038, %int0_9039, %298, %int1_9040 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7378, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_9041 = torch.constant.int 1 + %int0_9042 = torch.constant.int 0 + %int9223372036854775807_9043 = torch.constant.int 9223372036854775807 + %int1_9044 = torch.constant.int 1 + %7379 = torch.aten.slice.Tensor %7378, %int1_9041, %int0_9042, %int9223372036854775807_9043, %int1_9044 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7379, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_9045 = torch.constant.int 0 + %int0_9046 = torch.constant.int 0 + %int1_9047 = torch.constant.int 1 + %7380 = torch.aten.slice.Tensor %7377, %int0_9045, %int0_9046, %298, %int1_9047 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7380, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_9048 = torch.constant.int 1 + %int0_9049 = torch.constant.int 0 + %int9223372036854775807_9050 = torch.constant.int 9223372036854775807 + %int1_9051 = torch.constant.int 1 + %7381 = torch.aten.slice.Tensor %7380, %int1_9048, %int0_9049, %int9223372036854775807_9050, %int1_9051 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> 
!torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7381, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_9052 = torch.constant.int 0 + %7382 = torch.aten.unsqueeze %7379, %int0_9052 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7382, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_9053 = torch.constant.int 1 + %int0_9054 = torch.constant.int 0 + %int9223372036854775807_9055 = torch.constant.int 9223372036854775807 + %int1_9056 = torch.constant.int 1 + %7383 = torch.aten.slice.Tensor %7382, %int1_9053, %int0_9054, %int9223372036854775807_9055, %int1_9056 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7383, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_9057 = torch.constant.int 2 + %7384 = torch.aten.unsqueeze %7383, %int2_9057 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7384, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_9058 = torch.constant.int 3 + %int0_9059 = torch.constant.int 0 + %int9223372036854775807_9060 = torch.constant.int 9223372036854775807 + %int1_9061 = torch.constant.int 1 + %7385 = torch.aten.slice.Tensor %7384, %int3_9058, %int0_9059, %int9223372036854775807_9060, %int1_9061 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7385, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_9062 = torch.constant.int 4 + %int1_9063 = torch.constant.int 1 + %int1_9064 = torch.constant.int 1 + %int1_9065 = torch.constant.int 1 + %7386 = torch.prim.ListConstruct %int4_9062, %int1_9063, %int1_9064, %int1_9065 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7387 = torch.aten.repeat %7385, %7386 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7387, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_9066 = torch.constant.int 0 + %7388 = torch.aten.unsqueeze %7381, %int0_9066 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7388, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_9067 = torch.constant.int 1 + %int0_9068 = torch.constant.int 0 + %int9223372036854775807_9069 = torch.constant.int 9223372036854775807 + %int1_9070 = torch.constant.int 1 + %7389 = torch.aten.slice.Tensor %7388, %int1_9067, %int0_9068, %int9223372036854775807_9069, %int1_9070 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7389, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_9071 = torch.constant.int 2 + %7390 = torch.aten.unsqueeze %7389, %int2_9071 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7390, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_9072 = torch.constant.int 3 + %int0_9073 = torch.constant.int 0 + %int9223372036854775807_9074 = torch.constant.int 9223372036854775807 + %int1_9075 = torch.constant.int 1 + %7391 = 
torch.aten.slice.Tensor %7390, %int3_9072, %int0_9073, %int9223372036854775807_9074, %int1_9075 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7391, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_9076 = torch.constant.int 4 + %int1_9077 = torch.constant.int 1 + %int1_9078 = torch.constant.int 1 + %int1_9079 = torch.constant.int 1 + %7392 = torch.prim.ListConstruct %int4_9076, %int1_9077, %int1_9078, %int1_9079 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7393 = torch.aten.repeat %7391, %7392 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7393, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %7394 = torch.aten.mul.Tensor %7334, %7387 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7394, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_9080 = torch.constant.int 3 + %int0_9081 = torch.constant.int 0 + %int64_9082 = torch.constant.int 64 + %int1_9083 = torch.constant.int 1 + %7395 = torch.aten.slice.Tensor %7334, %int3_9080, %int0_9081, %int64_9082, %int1_9083 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %7395, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_9084 = torch.constant.int 3 + %int64_9085 = torch.constant.int 64 + %int9223372036854775807_9086 = torch.constant.int 9223372036854775807 + %int1_9087 = torch.constant.int 1 + %7396 = torch.aten.slice.Tensor %7334, %int3_9084, %int64_9085, %int9223372036854775807_9086, %int1_9087 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %7396, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %7397 = torch.aten.neg %7396 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %7397, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %7398 = torch.prim.ListConstruct %7397, %7395 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_9088 = torch.constant.int -1 + %7399 = torch.aten.cat %7398, %int-1_9088 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7399, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %7400 = torch.aten.mul.Tensor %7399, %7393 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7400, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_9089 = torch.constant.int 1 + %7401 = torch.aten.add.Tensor %7394, %7400, %int1_9089 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7401, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_9090 = torch.constant.int 131072 + %none_9091 = torch.constant.none + %none_9092 = torch.constant.none + %cpu_9093 = torch.constant.device "cpu" + %false_9094 = torch.constant.bool 
false + %7402 = torch.aten.arange %int131072_9090, %none_9091, %none_9092, %cpu_9093, %false_9094 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_9095 = torch.constant.int 0 + %int128_9096 = torch.constant.int 128 + %int2_9097 = torch.constant.int 2 + %int4_9098 = torch.constant.int 4 + %none_9099 = torch.constant.none + %cpu_9100 = torch.constant.device "cpu" + %false_9101 = torch.constant.bool false + %7403 = torch.aten.arange.start_step %int0_9095, %int128_9096, %int2_9097, %int4_9098, %none_9099, %cpu_9100, %false_9101 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_9102 = torch.constant.int 6 + %7404 = torch.prims.convert_element_type %7403, %int6_9102 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_9103 = torch.constant.int 128 + %7405 = torch.aten.div.Scalar %7404, %int128_9103 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_9104 = torch.constant.float 5.000000e+05 + %7406 = torch.aten.pow.Scalar %float5.000000e05_9104, %7405 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7407 = torch.aten.reciprocal %7406 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_9105 = torch.constant.float 1.000000e+00 + %7408 = torch.aten.mul.Scalar %7407, %float1.000000e00_9105 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %7409 = torch.aten.reciprocal %7408 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_9106 = torch.constant.float 6.2831853071795862 + %7410 = torch.aten.mul.Scalar %7409, %float6.283190e00_9106 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_9107 = torch.constant.float 8.192000e+03 + %7411 = torch.aten.gt.Scalar %7410, %float8.192000e03_9107 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_9108 = torch.constant.int 8 + %7412 = torch.aten.div.Scalar %7408, %int8_9108 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7413 = torch.aten.where.self %7411, %7412, %7408 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7414 = torch.aten.reciprocal %7410 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_9109 = torch.constant.int 8192 + %7415 = torch.aten.mul.Scalar %7414, %int8192_9109 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_9110 = torch.constant.int 1 + %int1_9111 = torch.constant.int 1 + %7416 = torch.aten.sub.Scalar %7415, %int1_9110, %int1_9111 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_9112 = torch.constant.int 3 + %7417 = torch.aten.div.Scalar %7416, %int3_9112 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_9113 = torch.constant.int 1 + %int1_9114 = torch.constant.int 1 + %7418 = torch.aten.rsub.Scalar %7417, %int1_9113, %int1_9114 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %7419 = torch.aten.mul.Tensor %7418, %7413 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_9115 = torch.constant.int 8 + %7420 = torch.aten.div.Scalar %7419, %int8_9115 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7421 = torch.aten.mul.Tensor %7417, %7413 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_9116 
= torch.constant.int 1 + %7422 = torch.aten.add.Tensor %7420, %7421, %int1_9116 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_9117 = torch.constant.float 2.048000e+03 + %7423 = torch.aten.lt.Scalar %7410, %float2.048000e03_9117 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7424 = torch.aten.bitwise_not %7423 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_9118 = torch.constant.float 8.192000e+03 + %7425 = torch.aten.gt.Scalar %7410, %float8.192000e03_9118 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7426 = torch.aten.bitwise_not %7425 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7427 = torch.aten.mul.Tensor %7424, %7426 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7428 = torch.aten.where.self %7427, %7422, %7413 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7429 = torch.prim.ListConstruct %7428, %7428 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_9119 = torch.constant.int -1 + %7430 = torch.aten.cat %7429, %int-1_9119 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_9120 = torch.constant.int 6 + %7431 = torch.prims.convert_element_type %7430, %int6_9120 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_9121 = torch.constant.int 1 + %7432 = torch.aten.unsqueeze %7402, %int1_9121 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_9122 = torch.constant.int 6 + %7433 = torch.prims.convert_element_type %7432, %int6_9122 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_9123 = torch.constant.int 0 + %7434 = torch.aten.unsqueeze %7431, %int0_9123 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_9124 = torch.constant.int 6 + %7435 = torch.prims.convert_element_type %7434, %int6_9124 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %7436 = torch.aten.mul.Tensor %7433, %7435 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %7437 = torch.aten.cos %7436 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_9125 = torch.constant.int 5 + %7438 = torch.prims.convert_element_type %7437, %int5_9125 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %7439 = torch.aten.sin %7436 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_9126 = torch.constant.int 5 + %7440 = torch.prims.convert_element_type %7439, %int5_9126 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_9127 = torch.constant.int 0 + %int0_9128 = torch.constant.int 0 + %int1_9129 = torch.constant.int 1 + %7441 = torch.aten.slice.Tensor %7438, %int0_9127, %int0_9128, %298, %int1_9129 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7441, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_9130 = torch.constant.int 1 + %int0_9131 = torch.constant.int 0 + %int9223372036854775807_9132 = torch.constant.int 9223372036854775807 + %int1_9133 = torch.constant.int 1 + %7442 = torch.aten.slice.Tensor %7441, %int1_9130, %int0_9131, %int9223372036854775807_9132, %int1_9133 : !torch.vtensor<[?,128],f16>, 
!torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7442, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_9134 = torch.constant.int 0 + %int0_9135 = torch.constant.int 0 + %int1_9136 = torch.constant.int 1 + %7443 = torch.aten.slice.Tensor %7440, %int0_9134, %int0_9135, %298, %int1_9136 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7443, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_9137 = torch.constant.int 1 + %int0_9138 = torch.constant.int 0 + %int9223372036854775807_9139 = torch.constant.int 9223372036854775807 + %int1_9140 = torch.constant.int 1 + %7444 = torch.aten.slice.Tensor %7443, %int1_9137, %int0_9138, %int9223372036854775807_9139, %int1_9140 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7444, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_9141 = torch.constant.int 0 + %7445 = torch.aten.unsqueeze %7442, %int0_9141 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7445, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_9142 = torch.constant.int 1 + %int0_9143 = torch.constant.int 0 + %int9223372036854775807_9144 = torch.constant.int 9223372036854775807 + %int1_9145 = torch.constant.int 1 + %7446 = torch.aten.slice.Tensor %7445, %int1_9142, %int0_9143, %int9223372036854775807_9144, %int1_9145 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7446, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_9146 = torch.constant.int 2 + %7447 = torch.aten.unsqueeze %7446, %int2_9146 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7447, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_9147 = torch.constant.int 3 + %int0_9148 = torch.constant.int 0 + %int9223372036854775807_9149 = torch.constant.int 9223372036854775807 + %int1_9150 = torch.constant.int 1 + %7448 = torch.aten.slice.Tensor %7447, %int3_9147, %int0_9148, %int9223372036854775807_9149, %int1_9150 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7448, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_9151 = torch.constant.int 4 + %int1_9152 = torch.constant.int 1 + %int1_9153 = torch.constant.int 1 + %int1_9154 = torch.constant.int 1 + %7449 = torch.prim.ListConstruct %int4_9151, %int1_9152, %int1_9153, %int1_9154 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7450 = torch.aten.repeat %7448, %7449 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7450, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_9155 = torch.constant.int 0 + %7451 = torch.aten.unsqueeze %7444, %int0_9155 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7451, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_9156 = torch.constant.int 1 + 
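// RoPE tables for this block's K. %7403..%7428 above compute the 64 base
// inverse frequencies 1 / 500000^(i/128) (i = 0, 2, ..., 126) and rescale
// them per wavelength; the constants (factor 8, original context 8192,
// smoothing band 2048 <= wavelength <= 8192) match the Llama 3.1-style
// context-extension rule. A sketch of that reading (pseudocode, not ops from
// this module):
//
//   inv_freq = 1.0 / (500000.0 ** (arange(0, 128, 2) / 128.0))
//   wavelen  = 2 * pi / inv_freq
//   smooth   = (8192.0 / wavelen - 1.0) / 3.0   # (8192/wl - low)/(high - low), low = 1, high = 4
//   scaled   = where(wavelen > 8192.0, inv_freq / 8.0, inv_freq)
//   scaled   = where((wavelen >= 2048.0) & (wavelen <= 8192.0),
//                    (1.0 - smooth) * inv_freq / 8.0 + smooth * inv_freq,
//                    scaled)
//
// %7436..%7440 then form the outer product with positions 0..131071 and take
// cos/sin in f16; the slices from %7441 on trim the tables to the live
// sequence length (s0 * 32 in the bound symbolic shapes) and, continuing
// below, broadcast them to [4, seq, 1, 128] for application to the K
// projection %7336 (8 KV heads).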
%int0_9157 = torch.constant.int 0 + %int9223372036854775807_9158 = torch.constant.int 9223372036854775807 + %int1_9159 = torch.constant.int 1 + %7452 = torch.aten.slice.Tensor %7451, %int1_9156, %int0_9157, %int9223372036854775807_9158, %int1_9159 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7452, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_9160 = torch.constant.int 2 + %7453 = torch.aten.unsqueeze %7452, %int2_9160 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7453, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_9161 = torch.constant.int 3 + %int0_9162 = torch.constant.int 0 + %int9223372036854775807_9163 = torch.constant.int 9223372036854775807 + %int1_9164 = torch.constant.int 1 + %7454 = torch.aten.slice.Tensor %7453, %int3_9161, %int0_9162, %int9223372036854775807_9163, %int1_9164 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7454, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_9165 = torch.constant.int 4 + %int1_9166 = torch.constant.int 1 + %int1_9167 = torch.constant.int 1 + %int1_9168 = torch.constant.int 1 + %7455 = torch.prim.ListConstruct %int4_9165, %int1_9166, %int1_9167, %int1_9168 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7456 = torch.aten.repeat %7454, %7455 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7456, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %7457 = torch.aten.mul.Tensor %7336, %7450 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7457, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_9169 = torch.constant.int 3 + %int0_9170 = torch.constant.int 0 + %int64_9171 = torch.constant.int 64 + %int1_9172 = torch.constant.int 1 + %7458 = torch.aten.slice.Tensor %7336, %int3_9169, %int0_9170, %int64_9171, %int1_9172 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %7458, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_9173 = torch.constant.int 3 + %int64_9174 = torch.constant.int 64 + %int9223372036854775807_9175 = torch.constant.int 9223372036854775807 + %int1_9176 = torch.constant.int 1 + %7459 = torch.aten.slice.Tensor %7336, %int3_9173, %int64_9174, %int9223372036854775807_9175, %int1_9176 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %7459, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %7460 = torch.aten.neg %7459 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %7460, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %7461 = torch.prim.ListConstruct %7460, %7458 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_9177 = torch.constant.int -1 + %7462 = torch.aten.cat %7461, %int-1_9177 : !torch.list, !torch.int -> 
!torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7462, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %7463 = torch.aten.mul.Tensor %7462, %7456 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7463, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_9178 = torch.constant.int 1 + %7464 = torch.aten.add.Tensor %7457, %7463, %int1_9178 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7464, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_9179 = torch.constant.int 32 + %7465 = torch.aten.mul.Scalar %arg2, %int32_9179 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7465, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int26 = torch.constant.int 26 + %int1_9180 = torch.constant.int 1 + %7466 = torch.aten.add.Scalar %7465, %int26, %int1_9180 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7466, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_9181 = torch.constant.int 2 + %7467 = torch.aten.mul.Scalar %7466, %int2_9181 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7467, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_9182 = torch.constant.int 0 + %int1_9183 = torch.constant.int 1 + %7468 = torch.aten.add.Scalar %7467, %int0_9182, %int1_9183 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7468, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %7469 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %7470 = torch.aten.view %7468, %7469 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %7470, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_9184 = torch.constant.int 4 + %int32_9185 = torch.constant.int 32 + %int8_9186 = torch.constant.int 8 + %int128_9187 = torch.constant.int 128 + %7471 = torch.prim.ListConstruct %int4_9184, %296, %int32_9185, %int8_9186, %int128_9187 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7472 = torch.aten.view %7464, %7471 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7472, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_9188 = torch.constant.int 32 + %int8_9189 = torch.constant.int 8 + %int128_9190 = torch.constant.int 128 + %7473 = torch.prim.ListConstruct %504, %int32_9188, %int8_9189, %int128_9190 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7474 = torch.aten.view %7472, %7473 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %7474, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_9191 = torch.constant.int 1 + %int2_9192 = torch.constant.int 2 + %7475 = torch.aten.transpose.int %7474, %int1_9191, %int2_9192 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7475, [%294], affine_map<()[s0] -> (s0 * 4, 
8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_9193 = torch.constant.int 5 + %7476 = torch.prims.convert_element_type %7475, %int5_9193 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7476, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_9194 = torch.constant.int 32 + %int2_9195 = torch.constant.int 2 + %int8_9196 = torch.constant.int 8 + %int32_9197 = torch.constant.int 32 + %int128_9198 = torch.constant.int 128 + %7477 = torch.prim.ListConstruct %297, %int32_9194, %int2_9195, %int8_9196, %int32_9197, %int128_9198 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7478 = torch.aten.view %7240, %7477 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7478, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_9199 = torch.constant.int 8 + %int32_9200 = torch.constant.int 32 + %int128_9201 = torch.constant.int 128 + %7479 = torch.prim.ListConstruct %497, %int8_9199, %int32_9200, %int128_9201 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7480 = torch.aten.view %7478, %7479 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7480, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %7481 = torch.prim.ListConstruct %7470 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_9202 = torch.constant.bool false + %7482 = torch.aten.index_put %7480, %7481, %7476, %false_9202 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7482, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_9203 = torch.constant.int 32 + %int2_9204 = torch.constant.int 2 + %int8_9205 = torch.constant.int 8 + %int32_9206 = torch.constant.int 32 + %int128_9207 = torch.constant.int 128 + %7483 = torch.prim.ListConstruct %297, %int32_9203, %int2_9204, %int8_9205, %int32_9206, %int128_9207 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7484 = torch.aten.view %7482, %7483 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7484, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_9208 = torch.constant.int 2097152 + %7485 = torch.prim.ListConstruct %297, %int2097152_9208 : (!torch.int, !torch.int) -> !torch.list + %7486 = torch.aten.view %7484, %7485 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %7486, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_9209 = torch.constant.int 32 + %int2_9210 = torch.constant.int 2 + %int8_9211 = torch.constant.int 8 + %int32_9212 = torch.constant.int 32 + %int128_9213 = torch.constant.int 128 + %7487 = torch.prim.ListConstruct %297, %int32_9209, %int2_9210, %int8_9211, %int32_9212, %int128_9213 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7488 = torch.aten.view %7486, %7487 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7488, [%295], affine_map<()[s0] -> (s0, 32, 2, 
8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_9214 = torch.constant.int 8 + %int32_9215 = torch.constant.int 32 + %int128_9216 = torch.constant.int 128 + %7489 = torch.prim.ListConstruct %497, %int8_9214, %int32_9215, %int128_9216 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7490 = torch.aten.view %7488, %7489 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7490, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_9217 = torch.constant.int 32 + %7491 = torch.aten.mul.Scalar %arg2, %int32_9217 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7491, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int26_9218 = torch.constant.int 26 + %int1_9219 = torch.constant.int 1 + %7492 = torch.aten.add.Scalar %7491, %int26_9218, %int1_9219 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7492, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_9220 = torch.constant.int 2 + %7493 = torch.aten.mul.Scalar %7492, %int2_9220 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7493, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_9221 = torch.constant.int 1 + %int1_9222 = torch.constant.int 1 + %7494 = torch.aten.add.Scalar %7493, %int1_9221, %int1_9222 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7494, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %7495 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %7496 = torch.aten.view %7494, %7495 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %7496, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_9223 = torch.constant.int 4 + %int32_9224 = torch.constant.int 32 + %int8_9225 = torch.constant.int 8 + %int128_9226 = torch.constant.int 128 + %7497 = torch.prim.ListConstruct %int4_9223, %296, %int32_9224, %int8_9225, %int128_9226 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7498 = torch.aten.view %7338, %7497 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7498, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_9227 = torch.constant.int 32 + %int8_9228 = torch.constant.int 8 + %int128_9229 = torch.constant.int 128 + %7499 = torch.prim.ListConstruct %504, %int32_9227, %int8_9228, %int128_9229 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7500 = torch.aten.view %7498, %7499 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %7500, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_9230 = torch.constant.int 1 + %int2_9231 = torch.constant.int 2 + %7501 = torch.aten.transpose.int %7500, %int1_9230, %int2_9231 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7501, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_9232 = torch.constant.int 5 + %7502 = torch.prims.convert_element_type %7501, %int5_9232 : 
!torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7502, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %7503 = torch.prim.ListConstruct %7496 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_9233 = torch.constant.bool false + %7504 = torch.aten.index_put %7490, %7503, %7502, %false_9233 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7504, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_9234 = torch.constant.int 32 + %int2_9235 = torch.constant.int 2 + %int8_9236 = torch.constant.int 8 + %int32_9237 = torch.constant.int 32 + %int128_9238 = torch.constant.int 128 + %7505 = torch.prim.ListConstruct %297, %int32_9234, %int2_9235, %int8_9236, %int32_9237, %int128_9238 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7506 = torch.aten.view %7504, %7505 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7506, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_9239 = torch.constant.int 2097152 + %7507 = torch.prim.ListConstruct %297, %int2097152_9239 : (!torch.int, !torch.int) -> !torch.list + %7508 = torch.aten.view %7506, %7507 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %7508, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_9240 = torch.constant.int -2 + %7509 = torch.aten.unsqueeze %7464, %int-2_9240 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %7509, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_9241 = torch.constant.int 4 + %int8_9242 = torch.constant.int 8 + %int4_9243 = torch.constant.int 4 + %int128_9244 = torch.constant.int 128 + %7510 = torch.prim.ListConstruct %int4_9241, %298, %int8_9242, %int4_9243, %int128_9244 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_9245 = torch.constant.bool false + %7511 = torch.aten.expand %7509, %7510, %false_9245 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7511, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_9246 = torch.constant.int 0 + %7512 = torch.aten.clone %7511, %int0_9246 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7512, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_9247 = torch.constant.int 4 + %int32_9248 = torch.constant.int 32 + %int128_9249 = torch.constant.int 128 + %7513 = torch.prim.ListConstruct %int4_9247, %298, %int32_9248, %int128_9249 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7514 = torch.aten._unsafe_view %7512, %7513 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7514, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_9250 = torch.constant.int -2 + %7515 = torch.aten.unsqueeze %7338, %int-2_9250 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> 
!torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %7515, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_9251 = torch.constant.int 4 + %int8_9252 = torch.constant.int 8 + %int4_9253 = torch.constant.int 4 + %int128_9254 = torch.constant.int 128 + %7516 = torch.prim.ListConstruct %int4_9251, %298, %int8_9252, %int4_9253, %int128_9254 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_9255 = torch.constant.bool false + %7517 = torch.aten.expand %7515, %7516, %false_9255 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7517, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_9256 = torch.constant.int 0 + %7518 = torch.aten.clone %7517, %int0_9256 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7518, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_9257 = torch.constant.int 4 + %int32_9258 = torch.constant.int 32 + %int128_9259 = torch.constant.int 128 + %7519 = torch.prim.ListConstruct %int4_9257, %298, %int32_9258, %int128_9259 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7520 = torch.aten._unsafe_view %7518, %7519 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7520, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_9260 = torch.constant.int 1 + %int2_9261 = torch.constant.int 2 + %7521 = torch.aten.transpose.int %7401, %int1_9260, %int2_9261 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7521, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_9262 = torch.constant.int 1 + %int2_9263 = torch.constant.int 2 + %7522 = torch.aten.transpose.int %7514, %int1_9262, %int2_9263 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7522, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_9264 = torch.constant.int 1 + %int2_9265 = torch.constant.int 2 + %7523 = torch.aten.transpose.int %7520, %int1_9264, %int2_9265 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7523, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_9266 = torch.constant.float 0.000000e+00 + %false_9267 = torch.constant.bool false + %none_9268 = torch.constant.none + %7524:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%7521, %7522, %7523, %float0.000000e00_9266, %false_9267, %327, %none_9268) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %7524#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_9269 = torch.constant.int 1 + %int2_9270 = torch.constant.int 2 + %7525 = torch.aten.transpose.int %7524#0, %int1_9269, %int2_9270 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> 
!torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7525, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_9271 = torch.constant.int 4 + %int4096_9272 = torch.constant.int 4096 + %7526 = torch.prim.ListConstruct %int4_9271, %298, %int4096_9272 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7527 = torch.aten.view %7525, %7526 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7527, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_9273 = torch.constant.int -2 + %int-1_9274 = torch.constant.int -1 + %7528 = torch.aten.transpose.int %240, %int-2_9273, %int-1_9274 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_9275 = torch.constant.int 5 + %7529 = torch.prims.convert_element_type %7528, %int5_9275 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_9276 = torch.constant.int 4096 + %7530 = torch.prim.ListConstruct %342, %int4096_9276 : (!torch.int, !torch.int) -> !torch.list + %7531 = torch.aten.view %7527, %7530 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7531, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7532 = torch.aten.mm %7531, %7529 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7532, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_9277 = torch.constant.int 4 + %int4096_9278 = torch.constant.int 4096 + %7533 = torch.prim.ListConstruct %int4_9277, %298, %int4096_9278 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7534 = torch.aten.view %7532, %7533 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7534, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_9279 = torch.constant.int 1 + %7535 = torch.aten.add.Tensor %7301, %7534, %int1_9279 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7535, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_9280 = torch.constant.int 6 + %7536 = torch.prims.convert_element_type %7535, %int6_9280 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7536, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_9281 = torch.constant.int 2 + %7537 = torch.aten.pow.Tensor_Scalar %7536, %int2_9281 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7537, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_9282 = torch.constant.int -1 + %7538 = torch.prim.ListConstruct %int-1_9282 : (!torch.int) -> !torch.list + %true_9283 = torch.constant.bool true + %none_9284 = torch.constant.none + %7539 = torch.aten.mean.dim %7537, %7538, %true_9283, %none_9284 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7539, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_9285 = torch.constant.float 9.9999997473787516E-6 + %int1_9286 = torch.constant.int 1 + %7540 = 
torch.aten.add.Scalar %7539, %float9.999990e-06_9285, %int1_9286 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7540, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7541 = torch.aten.rsqrt %7540 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7541, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7542 = torch.aten.mul.Tensor %7536, %7541 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7542, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_9287 = torch.constant.int 5 + %7543 = torch.prims.convert_element_type %7542, %int5_9287 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7543, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %7544 = torch.aten.mul.Tensor %241, %7543 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7544, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_9288 = torch.constant.int 5 + %7545 = torch.prims.convert_element_type %7544, %int5_9288 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7545, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_9289 = torch.constant.int -2 + %int-1_9290 = torch.constant.int -1 + %7546 = torch.aten.transpose.int %242, %int-2_9289, %int-1_9290 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_9291 = torch.constant.int 5 + %7547 = torch.prims.convert_element_type %7546, %int5_9291 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_9292 = torch.constant.int 4096 + %7548 = torch.prim.ListConstruct %342, %int4096_9292 : (!torch.int, !torch.int) -> !torch.list + %7549 = torch.aten.view %7545, %7548 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7549, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7550 = torch.aten.mm %7549, %7547 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %7550, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_9293 = torch.constant.int 4 + %int14336_9294 = torch.constant.int 14336 + %7551 = torch.prim.ListConstruct %int4_9293, %298, %int14336_9294 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7552 = torch.aten.view %7550, %7551 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7552, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %7553 = torch.aten.silu %7552 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7553, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_9295 = torch.constant.int -2 + %int-1_9296 = torch.constant.int -1 + %7554 = torch.aten.transpose.int %243, %int-2_9295, %int-1_9296 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_9297 = torch.constant.int 5 + %7555 = 
torch.prims.convert_element_type %7554, %int5_9297 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_9298 = torch.constant.int 4096 + %7556 = torch.prim.ListConstruct %342, %int4096_9298 : (!torch.int, !torch.int) -> !torch.list + %7557 = torch.aten.view %7545, %7556 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7557, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7558 = torch.aten.mm %7557, %7555 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %7558, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_9299 = torch.constant.int 4 + %int14336_9300 = torch.constant.int 14336 + %7559 = torch.prim.ListConstruct %int4_9299, %298, %int14336_9300 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7560 = torch.aten.view %7558, %7559 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7560, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %7561 = torch.aten.mul.Tensor %7553, %7560 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7561, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_9301 = torch.constant.int -2 + %int-1_9302 = torch.constant.int -1 + %7562 = torch.aten.transpose.int %244, %int-2_9301, %int-1_9302 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_9303 = torch.constant.int 5 + %7563 = torch.prims.convert_element_type %7562, %int5_9303 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_9304 = torch.constant.int 14336 + %7564 = torch.prim.ListConstruct %342, %int14336_9304 : (!torch.int, !torch.int) -> !torch.list + %7565 = torch.aten.view %7561, %7564 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %7565, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %7566 = torch.aten.mm %7565, %7563 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7566, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_9305 = torch.constant.int 4 + %int4096_9306 = torch.constant.int 4096 + %7567 = torch.prim.ListConstruct %int4_9305, %298, %int4096_9306 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7568 = torch.aten.view %7566, %7567 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7568, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_9307 = torch.constant.int 1 + %7569 = torch.aten.add.Tensor %7535, %7568, %int1_9307 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7569, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_9308 = torch.constant.int 6 + %7570 = torch.prims.convert_element_type %7569, %int6_9308 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7570, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + 
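// Block epilogue. Above: the rotated K and raw V were scattered into the
// shared paged cache (arg3 viewed as [pages, 32, 2, 8, 32, 128]; linear slot
// (page * 32 + 26) * 2 + 0 for K, + 1 for V, suggesting this is block 26),
// the 8 KV heads were expanded 4x to match the 32 query heads (GQA), and
// _scaled_dot_product_flash_attention_for_cpu ran with mask %327. The output
// projection (%240) plus residual gave %7535; the SwiGLU FFN,
// (silu(x @ %242^T) * (x @ %243^T)) @ %244^T, plus residual gives %7569.
// Below, %7570 (f32) enters the next RMSNorm: mean of squares, + ~1e-5
// epsilon, rsqrt, then scaling by the next block's attn_norm weight (%245).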
%int2_9309 = torch.constant.int 2 + %7571 = torch.aten.pow.Tensor_Scalar %7570, %int2_9309 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7571, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_9310 = torch.constant.int -1 + %7572 = torch.prim.ListConstruct %int-1_9310 : (!torch.int) -> !torch.list + %true_9311 = torch.constant.bool true + %none_9312 = torch.constant.none + %7573 = torch.aten.mean.dim %7571, %7572, %true_9311, %none_9312 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7573, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_9313 = torch.constant.float 9.9999997473787516E-6 + %int1_9314 = torch.constant.int 1 + %7574 = torch.aten.add.Scalar %7573, %float9.999990e-06_9313, %int1_9314 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7574, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7575 = torch.aten.rsqrt %7574 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7575, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7576 = torch.aten.mul.Tensor %7570, %7575 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7576, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_9315 = torch.constant.int 5 + %7577 = torch.prims.convert_element_type %7576, %int5_9315 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7577, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %7578 = torch.aten.mul.Tensor %245, %7577 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7578, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_9316 = torch.constant.int 5 + %7579 = torch.prims.convert_element_type %7578, %int5_9316 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7579, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_9317 = torch.constant.int -2 + %int-1_9318 = torch.constant.int -1 + %7580 = torch.aten.transpose.int %246, %int-2_9317, %int-1_9318 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_9319 = torch.constant.int 5 + %7581 = torch.prims.convert_element_type %7580, %int5_9319 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_9320 = torch.constant.int 4096 + %7582 = torch.prim.ListConstruct %342, %int4096_9320 : (!torch.int, !torch.int) -> !torch.list + %7583 = torch.aten.view %7579, %7582 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7583, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7584 = torch.aten.mm %7583, %7581 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7584, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_9321 = torch.constant.int 4 + %int4096_9322 = torch.constant.int 4096 + 
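// Next block, attention projections. %7577/%7579 is the RMS-normalized
// hidden state; %7584 (reshaped just below to [4, seq, 4096]) is its query
// projection through %246 (4096 -> 4096, i.e. 32 heads x 128). The following
// ops project K (%247) and V (%248) down to 1024 = 8 KV heads x 128, the
// same 4:1 grouped-query ratio expanded earlier in this function.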
%7585 = torch.prim.ListConstruct %int4_9321, %298, %int4096_9322 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7586 = torch.aten.view %7584, %7585 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7586, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_9323 = torch.constant.int -2 + %int-1_9324 = torch.constant.int -1 + %7587 = torch.aten.transpose.int %247, %int-2_9323, %int-1_9324 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_9325 = torch.constant.int 5 + %7588 = torch.prims.convert_element_type %7587, %int5_9325 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_9326 = torch.constant.int 4096 + %7589 = torch.prim.ListConstruct %342, %int4096_9326 : (!torch.int, !torch.int) -> !torch.list + %7590 = torch.aten.view %7579, %7589 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7590, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7591 = torch.aten.mm %7590, %7588 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %7591, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_9327 = torch.constant.int 4 + %int1024_9328 = torch.constant.int 1024 + %7592 = torch.prim.ListConstruct %int4_9327, %298, %int1024_9328 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7593 = torch.aten.view %7591, %7592 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %7593, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_9329 = torch.constant.int -2 + %int-1_9330 = torch.constant.int -1 + %7594 = torch.aten.transpose.int %248, %int-2_9329, %int-1_9330 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_9331 = torch.constant.int 5 + %7595 = torch.prims.convert_element_type %7594, %int5_9331 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_9332 = torch.constant.int 4096 + %7596 = torch.prim.ListConstruct %342, %int4096_9332 : (!torch.int, !torch.int) -> !torch.list + %7597 = torch.aten.view %7579, %7596 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7597, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7598 = torch.aten.mm %7597, %7595 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %7598, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_9333 = torch.constant.int 4 + %int1024_9334 = torch.constant.int 1024 + %7599 = torch.prim.ListConstruct %int4_9333, %298, %int1024_9334 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7600 = torch.aten.view %7598, %7599 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %7600, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_9335 = torch.constant.int 4 + %int32_9336 = torch.constant.int 32 + %int128_9337 = torch.constant.int 128 + %7601 = torch.prim.ListConstruct %int4_9335, %298, %int32_9336, %int128_9337 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7602 = 
torch.aten.view %7586, %7601 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7602, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_9338 = torch.constant.int 4 + %int8_9339 = torch.constant.int 8 + %int128_9340 = torch.constant.int 128 + %7603 = torch.prim.ListConstruct %int4_9338, %298, %int8_9339, %int128_9340 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7604 = torch.aten.view %7593, %7603 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7604, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_9341 = torch.constant.int 4 + %int8_9342 = torch.constant.int 8 + %int128_9343 = torch.constant.int 128 + %7605 = torch.prim.ListConstruct %int4_9341, %298, %int8_9342, %int128_9343 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7606 = torch.aten.view %7600, %7605 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7606, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_9344 = torch.constant.int 131072 + %none_9345 = torch.constant.none + %none_9346 = torch.constant.none + %cpu_9347 = torch.constant.device "cpu" + %false_9348 = torch.constant.bool false + %7607 = torch.aten.arange %int131072_9344, %none_9345, %none_9346, %cpu_9347, %false_9348 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_9349 = torch.constant.int 0 + %int128_9350 = torch.constant.int 128 + %int2_9351 = torch.constant.int 2 + %int4_9352 = torch.constant.int 4 + %none_9353 = torch.constant.none + %cpu_9354 = torch.constant.device "cpu" + %false_9355 = torch.constant.bool false + %7608 = torch.aten.arange.start_step %int0_9349, %int128_9350, %int2_9351, %int4_9352, %none_9353, %cpu_9354, %false_9355 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_9356 = torch.constant.int 6 + %7609 = torch.prims.convert_element_type %7608, %int6_9356 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_9357 = torch.constant.int 128 + %7610 = torch.aten.div.Scalar %7609, %int128_9357 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_9358 = torch.constant.float 5.000000e+05 + %7611 = torch.aten.pow.Scalar %float5.000000e05_9358, %7610 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7612 = torch.aten.reciprocal %7611 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_9359 = torch.constant.float 1.000000e+00 + %7613 = torch.aten.mul.Scalar %7612, %float1.000000e00_9359 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %7614 = torch.aten.reciprocal %7613 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_9360 = torch.constant.float 6.2831853071795862 + %7615 = torch.aten.mul.Scalar %7614, %float6.283190e00_9360 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_9361 = torch.constant.float 8.192000e+03 + %7616 = torch.aten.gt.Scalar %7615, %float8.192000e03_9361 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_9362 = torch.constant.int 8 + %7617 = torch.aten.div.Scalar %7613, %int8_9362 : !torch.vtensor<[64],f32>, !torch.int -> 
!torch.vtensor<[64],f32> + %7618 = torch.aten.where.self %7616, %7617, %7613 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7619 = torch.aten.reciprocal %7615 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_9363 = torch.constant.int 8192 + %7620 = torch.aten.mul.Scalar %7619, %int8192_9363 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_9364 = torch.constant.int 1 + %int1_9365 = torch.constant.int 1 + %7621 = torch.aten.sub.Scalar %7620, %int1_9364, %int1_9365 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_9366 = torch.constant.int 3 + %7622 = torch.aten.div.Scalar %7621, %int3_9366 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_9367 = torch.constant.int 1 + %int1_9368 = torch.constant.int 1 + %7623 = torch.aten.rsub.Scalar %7622, %int1_9367, %int1_9368 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %7624 = torch.aten.mul.Tensor %7623, %7618 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_9369 = torch.constant.int 8 + %7625 = torch.aten.div.Scalar %7624, %int8_9369 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7626 = torch.aten.mul.Tensor %7622, %7618 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_9370 = torch.constant.int 1 + %7627 = torch.aten.add.Tensor %7625, %7626, %int1_9370 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_9371 = torch.constant.float 2.048000e+03 + %7628 = torch.aten.lt.Scalar %7615, %float2.048000e03_9371 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7629 = torch.aten.bitwise_not %7628 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_9372 = torch.constant.float 8.192000e+03 + %7630 = torch.aten.gt.Scalar %7615, %float8.192000e03_9372 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7631 = torch.aten.bitwise_not %7630 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7632 = torch.aten.mul.Tensor %7629, %7631 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7633 = torch.aten.where.self %7632, %7627, %7618 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7634 = torch.prim.ListConstruct %7633, %7633 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_9373 = torch.constant.int -1 + %7635 = torch.aten.cat %7634, %int-1_9373 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_9374 = torch.constant.int 6 + %7636 = torch.prims.convert_element_type %7635, %int6_9374 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_9375 = torch.constant.int 1 + %7637 = torch.aten.unsqueeze %7607, %int1_9375 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_9376 = torch.constant.int 6 + %7638 = torch.prims.convert_element_type %7637, %int6_9376 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_9377 = torch.constant.int 0 + %7639 = torch.aten.unsqueeze %7636, %int0_9377 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_9378 = torch.constant.int 6 + %7640 = torch.prims.convert_element_type %7639, %int6_9378 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %7641 = 
torch.aten.mul.Tensor %7638, %7640 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %7642 = torch.aten.cos %7641 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_9379 = torch.constant.int 5 + %7643 = torch.prims.convert_element_type %7642, %int5_9379 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %7644 = torch.aten.sin %7641 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_9380 = torch.constant.int 5 + %7645 = torch.prims.convert_element_type %7644, %int5_9380 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_9381 = torch.constant.int 0 + %int0_9382 = torch.constant.int 0 + %int1_9383 = torch.constant.int 1 + %7646 = torch.aten.slice.Tensor %7643, %int0_9381, %int0_9382, %298, %int1_9383 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7646, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_9384 = torch.constant.int 1 + %int0_9385 = torch.constant.int 0 + %int9223372036854775807_9386 = torch.constant.int 9223372036854775807 + %int1_9387 = torch.constant.int 1 + %7647 = torch.aten.slice.Tensor %7646, %int1_9384, %int0_9385, %int9223372036854775807_9386, %int1_9387 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7647, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_9388 = torch.constant.int 0 + %int0_9389 = torch.constant.int 0 + %int1_9390 = torch.constant.int 1 + %7648 = torch.aten.slice.Tensor %7645, %int0_9388, %int0_9389, %298, %int1_9390 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7648, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_9391 = torch.constant.int 1 + %int0_9392 = torch.constant.int 0 + %int9223372036854775807_9393 = torch.constant.int 9223372036854775807 + %int1_9394 = torch.constant.int 1 + %7649 = torch.aten.slice.Tensor %7648, %int1_9391, %int0_9392, %int9223372036854775807_9393, %int1_9394 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7649, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_9395 = torch.constant.int 0 + %7650 = torch.aten.unsqueeze %7647, %int0_9395 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7650, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_9396 = torch.constant.int 1 + %int0_9397 = torch.constant.int 0 + %int9223372036854775807_9398 = torch.constant.int 9223372036854775807 + %int1_9399 = torch.constant.int 1 + %7651 = torch.aten.slice.Tensor %7650, %int1_9396, %int0_9397, %int9223372036854775807_9398, %int1_9399 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7651, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_9400 = torch.constant.int 2 + %7652 = torch.aten.unsqueeze %7651, %int2_9400 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7652, [%294], affine_map<()[s0] -> 
(1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_9401 = torch.constant.int 3 + %int0_9402 = torch.constant.int 0 + %int9223372036854775807_9403 = torch.constant.int 9223372036854775807 + %int1_9404 = torch.constant.int 1 + %7653 = torch.aten.slice.Tensor %7652, %int3_9401, %int0_9402, %int9223372036854775807_9403, %int1_9404 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7653, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_9405 = torch.constant.int 4 + %int1_9406 = torch.constant.int 1 + %int1_9407 = torch.constant.int 1 + %int1_9408 = torch.constant.int 1 + %7654 = torch.prim.ListConstruct %int4_9405, %int1_9406, %int1_9407, %int1_9408 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7655 = torch.aten.repeat %7653, %7654 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7655, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_9409 = torch.constant.int 0 + %7656 = torch.aten.unsqueeze %7649, %int0_9409 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7656, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_9410 = torch.constant.int 1 + %int0_9411 = torch.constant.int 0 + %int9223372036854775807_9412 = torch.constant.int 9223372036854775807 + %int1_9413 = torch.constant.int 1 + %7657 = torch.aten.slice.Tensor %7656, %int1_9410, %int0_9411, %int9223372036854775807_9412, %int1_9413 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7657, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_9414 = torch.constant.int 2 + %7658 = torch.aten.unsqueeze %7657, %int2_9414 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7658, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_9415 = torch.constant.int 3 + %int0_9416 = torch.constant.int 0 + %int9223372036854775807_9417 = torch.constant.int 9223372036854775807 + %int1_9418 = torch.constant.int 1 + %7659 = torch.aten.slice.Tensor %7658, %int3_9415, %int0_9416, %int9223372036854775807_9417, %int1_9418 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7659, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_9419 = torch.constant.int 4 + %int1_9420 = torch.constant.int 1 + %int1_9421 = torch.constant.int 1 + %int1_9422 = torch.constant.int 1 + %7660 = torch.prim.ListConstruct %int4_9419, %int1_9420, %int1_9421, %int1_9422 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7661 = torch.aten.repeat %7659, %7660 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7661, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %7662 = torch.aten.mul.Tensor %7602, %7655 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7662, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_9423 = torch.constant.int 3 + 
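// Rotary embedding for the next block's queries: %7662 above is q * cos
// (q = %7602, 32 heads). The slices below take the two 64-wide halves of q,
// negate the upper half and concatenate (rotate_half), multiply by sin
// (%7661), and add, completing q_rot = q * cos + rotate_half(q) * sin. The
// same frequency-table construction then repeats (from %7670 on) for the
// K side of this block.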
%int0_9424 = torch.constant.int 0 + %int64_9425 = torch.constant.int 64 + %int1_9426 = torch.constant.int 1 + %7663 = torch.aten.slice.Tensor %7602, %int3_9423, %int0_9424, %int64_9425, %int1_9426 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %7663, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_9427 = torch.constant.int 3 + %int64_9428 = torch.constant.int 64 + %int9223372036854775807_9429 = torch.constant.int 9223372036854775807 + %int1_9430 = torch.constant.int 1 + %7664 = torch.aten.slice.Tensor %7602, %int3_9427, %int64_9428, %int9223372036854775807_9429, %int1_9430 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %7664, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %7665 = torch.aten.neg %7664 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %7665, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %7666 = torch.prim.ListConstruct %7665, %7663 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_9431 = torch.constant.int -1 + %7667 = torch.aten.cat %7666, %int-1_9431 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7667, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %7668 = torch.aten.mul.Tensor %7667, %7661 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7668, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_9432 = torch.constant.int 1 + %7669 = torch.aten.add.Tensor %7662, %7668, %int1_9432 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7669, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_9433 = torch.constant.int 131072 + %none_9434 = torch.constant.none + %none_9435 = torch.constant.none + %cpu_9436 = torch.constant.device "cpu" + %false_9437 = torch.constant.bool false + %7670 = torch.aten.arange %int131072_9433, %none_9434, %none_9435, %cpu_9436, %false_9437 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_9438 = torch.constant.int 0 + %int128_9439 = torch.constant.int 128 + %int2_9440 = torch.constant.int 2 + %int4_9441 = torch.constant.int 4 + %none_9442 = torch.constant.none + %cpu_9443 = torch.constant.device "cpu" + %false_9444 = torch.constant.bool false + %7671 = torch.aten.arange.start_step %int0_9438, %int128_9439, %int2_9440, %int4_9441, %none_9442, %cpu_9443, %false_9444 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_9445 = torch.constant.int 6 + %7672 = torch.prims.convert_element_type %7671, %int6_9445 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_9446 = torch.constant.int 128 + %7673 = torch.aten.div.Scalar %7672, %int128_9446 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_9447 = torch.constant.float 5.000000e+05 + %7674 = torch.aten.pow.Scalar %float5.000000e05_9447, %7673 : !torch.float, 
!torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7675 = torch.aten.reciprocal %7674 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_9448 = torch.constant.float 1.000000e+00 + %7676 = torch.aten.mul.Scalar %7675, %float1.000000e00_9448 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %7677 = torch.aten.reciprocal %7676 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_9449 = torch.constant.float 6.2831853071795862 + %7678 = torch.aten.mul.Scalar %7677, %float6.283190e00_9449 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_9450 = torch.constant.float 8.192000e+03 + %7679 = torch.aten.gt.Scalar %7678, %float8.192000e03_9450 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_9451 = torch.constant.int 8 + %7680 = torch.aten.div.Scalar %7676, %int8_9451 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7681 = torch.aten.where.self %7679, %7680, %7676 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7682 = torch.aten.reciprocal %7678 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_9452 = torch.constant.int 8192 + %7683 = torch.aten.mul.Scalar %7682, %int8192_9452 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_9453 = torch.constant.int 1 + %int1_9454 = torch.constant.int 1 + %7684 = torch.aten.sub.Scalar %7683, %int1_9453, %int1_9454 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_9455 = torch.constant.int 3 + %7685 = torch.aten.div.Scalar %7684, %int3_9455 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_9456 = torch.constant.int 1 + %int1_9457 = torch.constant.int 1 + %7686 = torch.aten.rsub.Scalar %7685, %int1_9456, %int1_9457 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %7687 = torch.aten.mul.Tensor %7686, %7681 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_9458 = torch.constant.int 8 + %7688 = torch.aten.div.Scalar %7687, %int8_9458 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7689 = torch.aten.mul.Tensor %7685, %7681 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_9459 = torch.constant.int 1 + %7690 = torch.aten.add.Tensor %7688, %7689, %int1_9459 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_9460 = torch.constant.float 2.048000e+03 + %7691 = torch.aten.lt.Scalar %7678, %float2.048000e03_9460 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7692 = torch.aten.bitwise_not %7691 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_9461 = torch.constant.float 8.192000e+03 + %7693 = torch.aten.gt.Scalar %7678, %float8.192000e03_9461 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7694 = torch.aten.bitwise_not %7693 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7695 = torch.aten.mul.Tensor %7692, %7694 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7696 = torch.aten.where.self %7695, %7690, %7681 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7697 = torch.prim.ListConstruct %7696, %7696 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_9462 = torch.constant.int 
-1 + %7698 = torch.aten.cat %7697, %int-1_9462 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_9463 = torch.constant.int 6 + %7699 = torch.prims.convert_element_type %7698, %int6_9463 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_9464 = torch.constant.int 1 + %7700 = torch.aten.unsqueeze %7670, %int1_9464 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_9465 = torch.constant.int 6 + %7701 = torch.prims.convert_element_type %7700, %int6_9465 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_9466 = torch.constant.int 0 + %7702 = torch.aten.unsqueeze %7699, %int0_9466 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_9467 = torch.constant.int 6 + %7703 = torch.prims.convert_element_type %7702, %int6_9467 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %7704 = torch.aten.mul.Tensor %7701, %7703 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %7705 = torch.aten.cos %7704 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_9468 = torch.constant.int 5 + %7706 = torch.prims.convert_element_type %7705, %int5_9468 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %7707 = torch.aten.sin %7704 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_9469 = torch.constant.int 5 + %7708 = torch.prims.convert_element_type %7707, %int5_9469 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_9470 = torch.constant.int 0 + %int0_9471 = torch.constant.int 0 + %int1_9472 = torch.constant.int 1 + %7709 = torch.aten.slice.Tensor %7706, %int0_9470, %int0_9471, %298, %int1_9472 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7709, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_9473 = torch.constant.int 1 + %int0_9474 = torch.constant.int 0 + %int9223372036854775807_9475 = torch.constant.int 9223372036854775807 + %int1_9476 = torch.constant.int 1 + %7710 = torch.aten.slice.Tensor %7709, %int1_9473, %int0_9474, %int9223372036854775807_9475, %int1_9476 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7710, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_9477 = torch.constant.int 0 + %int0_9478 = torch.constant.int 0 + %int1_9479 = torch.constant.int 1 + %7711 = torch.aten.slice.Tensor %7708, %int0_9477, %int0_9478, %298, %int1_9479 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7711, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_9480 = torch.constant.int 1 + %int0_9481 = torch.constant.int 0 + %int9223372036854775807_9482 = torch.constant.int 9223372036854775807 + %int1_9483 = torch.constant.int 1 + %7712 = torch.aten.slice.Tensor %7711, %int1_9480, %int0_9481, %int9223372036854775807_9482, %int1_9483 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7712, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_9484 = torch.constant.int 0 + %7713 = torch.aten.unsqueeze %7710, 
%int0_9484 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7713, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_9485 = torch.constant.int 1 + %int0_9486 = torch.constant.int 0 + %int9223372036854775807_9487 = torch.constant.int 9223372036854775807 + %int1_9488 = torch.constant.int 1 + %7714 = torch.aten.slice.Tensor %7713, %int1_9485, %int0_9486, %int9223372036854775807_9487, %int1_9488 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7714, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_9489 = torch.constant.int 2 + %7715 = torch.aten.unsqueeze %7714, %int2_9489 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7715, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_9490 = torch.constant.int 3 + %int0_9491 = torch.constant.int 0 + %int9223372036854775807_9492 = torch.constant.int 9223372036854775807 + %int1_9493 = torch.constant.int 1 + %7716 = torch.aten.slice.Tensor %7715, %int3_9490, %int0_9491, %int9223372036854775807_9492, %int1_9493 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7716, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_9494 = torch.constant.int 4 + %int1_9495 = torch.constant.int 1 + %int1_9496 = torch.constant.int 1 + %int1_9497 = torch.constant.int 1 + %7717 = torch.prim.ListConstruct %int4_9494, %int1_9495, %int1_9496, %int1_9497 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7718 = torch.aten.repeat %7716, %7717 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7718, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_9498 = torch.constant.int 0 + %7719 = torch.aten.unsqueeze %7712, %int0_9498 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7719, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_9499 = torch.constant.int 1 + %int0_9500 = torch.constant.int 0 + %int9223372036854775807_9501 = torch.constant.int 9223372036854775807 + %int1_9502 = torch.constant.int 1 + %7720 = torch.aten.slice.Tensor %7719, %int1_9499, %int0_9500, %int9223372036854775807_9501, %int1_9502 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7720, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_9503 = torch.constant.int 2 + %7721 = torch.aten.unsqueeze %7720, %int2_9503 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7721, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_9504 = torch.constant.int 3 + %int0_9505 = torch.constant.int 0 + %int9223372036854775807_9506 = torch.constant.int 9223372036854775807 + %int1_9507 = torch.constant.int 1 + %7722 = torch.aten.slice.Tensor %7721, %int3_9504, %int0_9505, %int9223372036854775807_9506, %int1_9507 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + 
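+ // %7670..%7708 above rebuild the rotary frequency table inline: inv_freq = 500000^(-2i/128)
+ // for i in [0, 64), rescaled per wavelength in a way consistent with the Llama 3.1 recipe
+ // (scale factor 8, original context 8192: wavelengths above 8192 are divided by 8, and the
+ // 2048..8192 band is blended via ((8192 / wavelen) - 1) / 3), then expanded over positions
+ // 0..131071 and materialized as f16 cos/sin tables. The slice/unsqueeze/repeat chains here
+ // trim those tables to the current sequence length and broadcast them to [4, seq, 1, 128]
+ // for the 8 key heads of %7604.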
torch.bind_symbolic_shape %7722, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_9508 = torch.constant.int 4 + %int1_9509 = torch.constant.int 1 + %int1_9510 = torch.constant.int 1 + %int1_9511 = torch.constant.int 1 + %7723 = torch.prim.ListConstruct %int4_9508, %int1_9509, %int1_9510, %int1_9511 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7724 = torch.aten.repeat %7722, %7723 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7724, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %7725 = torch.aten.mul.Tensor %7604, %7718 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7725, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_9512 = torch.constant.int 3 + %int0_9513 = torch.constant.int 0 + %int64_9514 = torch.constant.int 64 + %int1_9515 = torch.constant.int 1 + %7726 = torch.aten.slice.Tensor %7604, %int3_9512, %int0_9513, %int64_9514, %int1_9515 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %7726, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_9516 = torch.constant.int 3 + %int64_9517 = torch.constant.int 64 + %int9223372036854775807_9518 = torch.constant.int 9223372036854775807 + %int1_9519 = torch.constant.int 1 + %7727 = torch.aten.slice.Tensor %7604, %int3_9516, %int64_9517, %int9223372036854775807_9518, %int1_9519 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %7727, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %7728 = torch.aten.neg %7727 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %7728, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %7729 = torch.prim.ListConstruct %7728, %7726 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_9520 = torch.constant.int -1 + %7730 = torch.aten.cat %7729, %int-1_9520 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7730, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %7731 = torch.aten.mul.Tensor %7730, %7724 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7731, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_9521 = torch.constant.int 1 + %7732 = torch.aten.add.Tensor %7725, %7731, %int1_9521 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7732, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_9522 = torch.constant.int 32 + %7733 = torch.aten.mul.Scalar %arg2, %int32_9522 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7733, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int27 = torch.constant.int 27 + %int1_9523 = torch.constant.int 1 + %7734 = torch.aten.add.Scalar %7733, %int27, %int1_9523 : !torch.vtensor<[4,?],si64>, !torch.int, 
!torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7734, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_9524 = torch.constant.int 2 + %7735 = torch.aten.mul.Scalar %7734, %int2_9524 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7735, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_9525 = torch.constant.int 0 + %int1_9526 = torch.constant.int 1 + %7736 = torch.aten.add.Scalar %7735, %int0_9525, %int1_9526 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7736, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %7737 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %7738 = torch.aten.view %7736, %7737 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %7738, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_9527 = torch.constant.int 4 + %int32_9528 = torch.constant.int 32 + %int8_9529 = torch.constant.int 8 + %int128_9530 = torch.constant.int 128 + %7739 = torch.prim.ListConstruct %int4_9527, %296, %int32_9528, %int8_9529, %int128_9530 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7740 = torch.aten.view %7732, %7739 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7740, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_9531 = torch.constant.int 32 + %int8_9532 = torch.constant.int 8 + %int128_9533 = torch.constant.int 128 + %7741 = torch.prim.ListConstruct %504, %int32_9531, %int8_9532, %int128_9533 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7742 = torch.aten.view %7740, %7741 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %7742, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_9534 = torch.constant.int 1 + %int2_9535 = torch.constant.int 2 + %7743 = torch.aten.transpose.int %7742, %int1_9534, %int2_9535 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7743, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_9536 = torch.constant.int 5 + %7744 = torch.prims.convert_element_type %7743, %int5_9536 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7744, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_9537 = torch.constant.int 32 + %int2_9538 = torch.constant.int 2 + %int8_9539 = torch.constant.int 8 + %int32_9540 = torch.constant.int 32 + %int128_9541 = torch.constant.int 128 + %7745 = torch.prim.ListConstruct %297, %int32_9537, %int2_9538, %int8_9539, %int32_9540, %int128_9541 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7746 = torch.aten.view %7508, %7745 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7746, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_9542 = torch.constant.int 8 + %int32_9543 = torch.constant.int 32 + %int128_9544 = torch.constant.int 128 + %7747 = torch.prim.ListConstruct %497, %int8_9542, %int32_9543, 
%int128_9544 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7748 = torch.aten.view %7746, %7747 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7748, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %7749 = torch.prim.ListConstruct %7738 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_9545 = torch.constant.bool false + %7750 = torch.aten.index_put %7748, %7749, %7744, %false_9545 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7750, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_9546 = torch.constant.int 32 + %int2_9547 = torch.constant.int 2 + %int8_9548 = torch.constant.int 8 + %int32_9549 = torch.constant.int 32 + %int128_9550 = torch.constant.int 128 + %7751 = torch.prim.ListConstruct %297, %int32_9546, %int2_9547, %int8_9548, %int32_9549, %int128_9550 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7752 = torch.aten.view %7750, %7751 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7752, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_9551 = torch.constant.int 2097152 + %7753 = torch.prim.ListConstruct %297, %int2097152_9551 : (!torch.int, !torch.int) -> !torch.list + %7754 = torch.aten.view %7752, %7753 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %7754, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_9552 = torch.constant.int 32 + %int2_9553 = torch.constant.int 2 + %int8_9554 = torch.constant.int 8 + %int32_9555 = torch.constant.int 32 + %int128_9556 = torch.constant.int 128 + %7755 = torch.prim.ListConstruct %297, %int32_9552, %int2_9553, %int8_9554, %int32_9555, %int128_9556 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7756 = torch.aten.view %7754, %7755 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7756, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_9557 = torch.constant.int 8 + %int32_9558 = torch.constant.int 32 + %int128_9559 = torch.constant.int 128 + %7757 = torch.prim.ListConstruct %497, %int8_9557, %int32_9558, %int128_9559 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7758 = torch.aten.view %7756, %7757 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7758, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_9560 = torch.constant.int 32 + %7759 = torch.aten.mul.Scalar %arg2, %int32_9560 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7759, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int27_9561 = torch.constant.int 27 + %int1_9562 = torch.constant.int 1 + %7760 = torch.aten.add.Scalar %7759, %int27_9561, %int1_9562 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7760, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_9563 = 
torch.constant.int 2 + %7761 = torch.aten.mul.Scalar %7760, %int2_9563 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7761, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_9564 = torch.constant.int 1 + %int1_9565 = torch.constant.int 1 + %7762 = torch.aten.add.Scalar %7761, %int1_9564, %int1_9565 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %7762, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %7763 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %7764 = torch.aten.view %7762, %7763 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %7764, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_9566 = torch.constant.int 4 + %int32_9567 = torch.constant.int 32 + %int8_9568 = torch.constant.int 8 + %int128_9569 = torch.constant.int 128 + %7765 = torch.prim.ListConstruct %int4_9566, %296, %int32_9567, %int8_9568, %int128_9569 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7766 = torch.aten.view %7606, %7765 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7766, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_9570 = torch.constant.int 32 + %int8_9571 = torch.constant.int 8 + %int128_9572 = torch.constant.int 128 + %7767 = torch.prim.ListConstruct %504, %int32_9570, %int8_9571, %int128_9572 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7768 = torch.aten.view %7766, %7767 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %7768, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_9573 = torch.constant.int 1 + %int2_9574 = torch.constant.int 2 + %7769 = torch.aten.transpose.int %7768, %int1_9573, %int2_9574 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7769, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_9575 = torch.constant.int 5 + %7770 = torch.prims.convert_element_type %7769, %int5_9575 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7770, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %7771 = torch.prim.ListConstruct %7764 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_9576 = torch.constant.bool false + %7772 = torch.aten.index_put %7758, %7771, %7770, %false_9576 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %7772, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_9577 = torch.constant.int 32 + %int2_9578 = torch.constant.int 2 + %int8_9579 = torch.constant.int 8 + %int32_9580 = torch.constant.int 32 + %int128_9581 = torch.constant.int 128 + %7773 = torch.prim.ListConstruct %297, %int32_9577, %int2_9578, %int8_9579, %int32_9580, %int128_9581 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7774 = torch.aten.view %7772, %7773 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape 
%7774, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_9582 = torch.constant.int 2097152 + %7775 = torch.prim.ListConstruct %297, %int2097152_9582 : (!torch.int, !torch.int) -> !torch.list + %7776 = torch.aten.view %7774, %7775 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %7776, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_9583 = torch.constant.int -2 + %7777 = torch.aten.unsqueeze %7732, %int-2_9583 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %7777, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_9584 = torch.constant.int 4 + %int8_9585 = torch.constant.int 8 + %int4_9586 = torch.constant.int 4 + %int128_9587 = torch.constant.int 128 + %7778 = torch.prim.ListConstruct %int4_9584, %298, %int8_9585, %int4_9586, %int128_9587 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_9588 = torch.constant.bool false + %7779 = torch.aten.expand %7777, %7778, %false_9588 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7779, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_9589 = torch.constant.int 0 + %7780 = torch.aten.clone %7779, %int0_9589 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7780, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_9590 = torch.constant.int 4 + %int32_9591 = torch.constant.int 32 + %int128_9592 = torch.constant.int 128 + %7781 = torch.prim.ListConstruct %int4_9590, %298, %int32_9591, %int128_9592 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7782 = torch.aten._unsafe_view %7780, %7781 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7782, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_9593 = torch.constant.int -2 + %7783 = torch.aten.unsqueeze %7606, %int-2_9593 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %7783, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_9594 = torch.constant.int 4 + %int8_9595 = torch.constant.int 8 + %int4_9596 = torch.constant.int 4 + %int128_9597 = torch.constant.int 128 + %7784 = torch.prim.ListConstruct %int4_9594, %298, %int8_9595, %int4_9596, %int128_9597 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_9598 = torch.constant.bool false + %7785 = torch.aten.expand %7783, %7784, %false_9598 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7785, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_9599 = torch.constant.int 0 + %7786 = torch.aten.clone %7785, %int0_9599 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7786, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_9600 = torch.constant.int 4 + %int32_9601 = torch.constant.int 32 + %int128_9602 = 
torch.constant.int 128 + %7787 = torch.prim.ListConstruct %int4_9600, %298, %int32_9601, %int128_9602 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7788 = torch.aten._unsafe_view %7786, %7787 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7788, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_9603 = torch.constant.int 1 + %int2_9604 = torch.constant.int 2 + %7789 = torch.aten.transpose.int %7669, %int1_9603, %int2_9604 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7789, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_9605 = torch.constant.int 1 + %int2_9606 = torch.constant.int 2 + %7790 = torch.aten.transpose.int %7782, %int1_9605, %int2_9606 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7790, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_9607 = torch.constant.int 1 + %int2_9608 = torch.constant.int 2 + %7791 = torch.aten.transpose.int %7788, %int1_9607, %int2_9608 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7791, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_9609 = torch.constant.float 0.000000e+00 + %false_9610 = torch.constant.bool false + %none_9611 = torch.constant.none + %7792:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%7789, %7790, %7791, %float0.000000e00_9609, %false_9610, %327, %none_9611) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %7792#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_9612 = torch.constant.int 1 + %int2_9613 = torch.constant.int 2 + %7793 = torch.aten.transpose.int %7792#0, %int1_9612, %int2_9613 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7793, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_9614 = torch.constant.int 4 + %int4096_9615 = torch.constant.int 4096 + %7794 = torch.prim.ListConstruct %int4_9614, %298, %int4096_9615 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7795 = torch.aten.view %7793, %7794 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7795, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_9616 = torch.constant.int -2 + %int-1_9617 = torch.constant.int -1 + %7796 = torch.aten.transpose.int %249, %int-2_9616, %int-1_9617 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_9618 = torch.constant.int 5 + %7797 = torch.prims.convert_element_type %7796, %int5_9618 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_9619 = torch.constant.int 4096 + %7798 = torch.prim.ListConstruct %342, %int4096_9619 : (!torch.int, !torch.int) -> !torch.list + %7799 = torch.aten.view %7795, %7798 : 
!torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7799, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7800 = torch.aten.mm %7799, %7797 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7800, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_9620 = torch.constant.int 4 + %int4096_9621 = torch.constant.int 4096 + %7801 = torch.prim.ListConstruct %int4_9620, %298, %int4096_9621 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7802 = torch.aten.view %7800, %7801 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7802, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_9622 = torch.constant.int 1 + %7803 = torch.aten.add.Tensor %7569, %7802, %int1_9622 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7803, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_9623 = torch.constant.int 6 + %7804 = torch.prims.convert_element_type %7803, %int6_9623 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7804, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_9624 = torch.constant.int 2 + %7805 = torch.aten.pow.Tensor_Scalar %7804, %int2_9624 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7805, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_9625 = torch.constant.int -1 + %7806 = torch.prim.ListConstruct %int-1_9625 : (!torch.int) -> !torch.list + %true_9626 = torch.constant.bool true + %none_9627 = torch.constant.none + %7807 = torch.aten.mean.dim %7805, %7806, %true_9626, %none_9627 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7807, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_9628 = torch.constant.float 9.9999997473787516E-6 + %int1_9629 = torch.constant.int 1 + %7808 = torch.aten.add.Scalar %7807, %float9.999990e-06_9628, %int1_9629 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7808, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7809 = torch.aten.rsqrt %7808 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7809, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7810 = torch.aten.mul.Tensor %7804, %7809 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7810, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_9630 = torch.constant.int 5 + %7811 = torch.prims.convert_element_type %7810, %int5_9630 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7811, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %7812 = torch.aten.mul.Tensor %250, %7811 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7812, 
[%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_9631 = torch.constant.int 5 + %7813 = torch.prims.convert_element_type %7812, %int5_9631 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7813, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_9632 = torch.constant.int -2 + %int-1_9633 = torch.constant.int -1 + %7814 = torch.aten.transpose.int %251, %int-2_9632, %int-1_9633 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_9634 = torch.constant.int 5 + %7815 = torch.prims.convert_element_type %7814, %int5_9634 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_9635 = torch.constant.int 4096 + %7816 = torch.prim.ListConstruct %342, %int4096_9635 : (!torch.int, !torch.int) -> !torch.list + %7817 = torch.aten.view %7813, %7816 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7817, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7818 = torch.aten.mm %7817, %7815 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %7818, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_9636 = torch.constant.int 4 + %int14336_9637 = torch.constant.int 14336 + %7819 = torch.prim.ListConstruct %int4_9636, %298, %int14336_9637 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7820 = torch.aten.view %7818, %7819 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7820, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %7821 = torch.aten.silu %7820 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7821, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_9638 = torch.constant.int -2 + %int-1_9639 = torch.constant.int -1 + %7822 = torch.aten.transpose.int %252, %int-2_9638, %int-1_9639 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_9640 = torch.constant.int 5 + %7823 = torch.prims.convert_element_type %7822, %int5_9640 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_9641 = torch.constant.int 4096 + %7824 = torch.prim.ListConstruct %342, %int4096_9641 : (!torch.int, !torch.int) -> !torch.list + %7825 = torch.aten.view %7813, %7824 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7825, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7826 = torch.aten.mm %7825, %7823 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %7826, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_9642 = torch.constant.int 4 + %int14336_9643 = torch.constant.int 14336 + %7827 = torch.prim.ListConstruct %int4_9642, %298, %int14336_9643 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7828 = torch.aten.view %7826, %7827 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7828, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %7829 
= torch.aten.mul.Tensor %7821, %7828 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %7829, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_9644 = torch.constant.int -2 + %int-1_9645 = torch.constant.int -1 + %7830 = torch.aten.transpose.int %253, %int-2_9644, %int-1_9645 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_9646 = torch.constant.int 5 + %7831 = torch.prims.convert_element_type %7830, %int5_9646 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_9647 = torch.constant.int 14336 + %7832 = torch.prim.ListConstruct %342, %int14336_9647 : (!torch.int, !torch.int) -> !torch.list + %7833 = torch.aten.view %7829, %7832 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %7833, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %7834 = torch.aten.mm %7833, %7831 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7834, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_9648 = torch.constant.int 4 + %int4096_9649 = torch.constant.int 4096 + %7835 = torch.prim.ListConstruct %int4_9648, %298, %int4096_9649 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7836 = torch.aten.view %7834, %7835 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7836, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_9650 = torch.constant.int 1 + %7837 = torch.aten.add.Tensor %7803, %7836, %int1_9650 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7837, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_9651 = torch.constant.int 6 + %7838 = torch.prims.convert_element_type %7837, %int6_9651 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7838, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_9652 = torch.constant.int 2 + %7839 = torch.aten.pow.Tensor_Scalar %7838, %int2_9652 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7839, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_9653 = torch.constant.int -1 + %7840 = torch.prim.ListConstruct %int-1_9653 : (!torch.int) -> !torch.list + %true_9654 = torch.constant.bool true + %none_9655 = torch.constant.none + %7841 = torch.aten.mean.dim %7839, %7840, %true_9654, %none_9655 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7841, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_9656 = torch.constant.float 9.9999997473787516E-6 + %int1_9657 = torch.constant.int 1 + %7842 = torch.aten.add.Scalar %7841, %float9.999990e-06_9656, %int1_9657 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7842, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7843 = torch.aten.rsqrt %7842 : !torch.vtensor<[4,?,1],f32> -> 
!torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %7843, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %7844 = torch.aten.mul.Tensor %7838, %7843 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7844, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_9658 = torch.constant.int 5 + %7845 = torch.prims.convert_element_type %7844, %int5_9658 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7845, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %7846 = torch.aten.mul.Tensor %254, %7845 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %7846, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_9659 = torch.constant.int 5 + %7847 = torch.prims.convert_element_type %7846, %int5_9659 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7847, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_9660 = torch.constant.int -2 + %int-1_9661 = torch.constant.int -1 + %7848 = torch.aten.transpose.int %255, %int-2_9660, %int-1_9661 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_9662 = torch.constant.int 5 + %7849 = torch.prims.convert_element_type %7848, %int5_9662 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_9663 = torch.constant.int 4096 + %7850 = torch.prim.ListConstruct %342, %int4096_9663 : (!torch.int, !torch.int) -> !torch.list + %7851 = torch.aten.view %7847, %7850 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7851, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7852 = torch.aten.mm %7851, %7849 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7852, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_9664 = torch.constant.int 4 + %int4096_9665 = torch.constant.int 4096 + %7853 = torch.prim.ListConstruct %int4_9664, %298, %int4096_9665 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7854 = torch.aten.view %7852, %7853 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %7854, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_9666 = torch.constant.int -2 + %int-1_9667 = torch.constant.int -1 + %7855 = torch.aten.transpose.int %256, %int-2_9666, %int-1_9667 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_9668 = torch.constant.int 5 + %7856 = torch.prims.convert_element_type %7855, %int5_9668 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_9669 = torch.constant.int 4096 + %7857 = torch.prim.ListConstruct %342, %int4096_9669 : (!torch.int, !torch.int) -> !torch.list + %7858 = torch.aten.view %7847, %7857 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7858, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7859 = torch.aten.mm %7858, %7856 : !torch.vtensor<[?,4096],f16>, 
!torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %7859, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_9670 = torch.constant.int 4 + %int1024_9671 = torch.constant.int 1024 + %7860 = torch.prim.ListConstruct %int4_9670, %298, %int1024_9671 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7861 = torch.aten.view %7859, %7860 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %7861, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_9672 = torch.constant.int -2 + %int-1_9673 = torch.constant.int -1 + %7862 = torch.aten.transpose.int %257, %int-2_9672, %int-1_9673 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_9674 = torch.constant.int 5 + %7863 = torch.prims.convert_element_type %7862, %int5_9674 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_9675 = torch.constant.int 4096 + %7864 = torch.prim.ListConstruct %342, %int4096_9675 : (!torch.int, !torch.int) -> !torch.list + %7865 = torch.aten.view %7847, %7864 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %7865, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %7866 = torch.aten.mm %7865, %7863 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %7866, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_9676 = torch.constant.int 4 + %int1024_9677 = torch.constant.int 1024 + %7867 = torch.prim.ListConstruct %int4_9676, %298, %int1024_9677 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7868 = torch.aten.view %7866, %7867 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %7868, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_9678 = torch.constant.int 4 + %int32_9679 = torch.constant.int 32 + %int128_9680 = torch.constant.int 128 + %7869 = torch.prim.ListConstruct %int4_9678, %298, %int32_9679, %int128_9680 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7870 = torch.aten.view %7854, %7869 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7870, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_9681 = torch.constant.int 4 + %int8_9682 = torch.constant.int 8 + %int128_9683 = torch.constant.int 128 + %7871 = torch.prim.ListConstruct %int4_9681, %298, %int8_9682, %int128_9683 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7872 = torch.aten.view %7861, %7871 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7872, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_9684 = torch.constant.int 4 + %int8_9685 = torch.constant.int 8 + %int128_9686 = torch.constant.int 128 + %7873 = torch.prim.ListConstruct %int4_9684, %298, %int8_9685, %int128_9686 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7874 = torch.aten.view %7868, %7873 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7874, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : 
!torch.vtensor<[4,?,8,128],f16> + %int131072_9687 = torch.constant.int 131072 + %none_9688 = torch.constant.none + %none_9689 = torch.constant.none + %cpu_9690 = torch.constant.device "cpu" + %false_9691 = torch.constant.bool false + %7875 = torch.aten.arange %int131072_9687, %none_9688, %none_9689, %cpu_9690, %false_9691 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_9692 = torch.constant.int 0 + %int128_9693 = torch.constant.int 128 + %int2_9694 = torch.constant.int 2 + %int4_9695 = torch.constant.int 4 + %none_9696 = torch.constant.none + %cpu_9697 = torch.constant.device "cpu" + %false_9698 = torch.constant.bool false + %7876 = torch.aten.arange.start_step %int0_9692, %int128_9693, %int2_9694, %int4_9695, %none_9696, %cpu_9697, %false_9698 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_9699 = torch.constant.int 6 + %7877 = torch.prims.convert_element_type %7876, %int6_9699 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_9700 = torch.constant.int 128 + %7878 = torch.aten.div.Scalar %7877, %int128_9700 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_9701 = torch.constant.float 5.000000e+05 + %7879 = torch.aten.pow.Scalar %float5.000000e05_9701, %7878 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7880 = torch.aten.reciprocal %7879 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_9702 = torch.constant.float 1.000000e+00 + %7881 = torch.aten.mul.Scalar %7880, %float1.000000e00_9702 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %7882 = torch.aten.reciprocal %7881 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_9703 = torch.constant.float 6.2831853071795862 + %7883 = torch.aten.mul.Scalar %7882, %float6.283190e00_9703 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_9704 = torch.constant.float 8.192000e+03 + %7884 = torch.aten.gt.Scalar %7883, %float8.192000e03_9704 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_9705 = torch.constant.int 8 + %7885 = torch.aten.div.Scalar %7881, %int8_9705 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7886 = torch.aten.where.self %7884, %7885, %7881 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7887 = torch.aten.reciprocal %7883 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_9706 = torch.constant.int 8192 + %7888 = torch.aten.mul.Scalar %7887, %int8192_9706 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_9707 = torch.constant.int 1 + %int1_9708 = torch.constant.int 1 + %7889 = torch.aten.sub.Scalar %7888, %int1_9707, %int1_9708 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_9709 = torch.constant.int 3 + %7890 = torch.aten.div.Scalar %7889, %int3_9709 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_9710 = torch.constant.int 1 + %int1_9711 = torch.constant.int 1 + %7891 = torch.aten.rsub.Scalar %7890, %int1_9710, %int1_9711 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %7892 = torch.aten.mul.Tensor %7891, %7886 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_9712 = torch.constant.int 8 + %7893 = 
torch.aten.div.Scalar %7892, %int8_9712 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7894 = torch.aten.mul.Tensor %7890, %7886 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_9713 = torch.constant.int 1 + %7895 = torch.aten.add.Tensor %7893, %7894, %int1_9713 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_9714 = torch.constant.float 2.048000e+03 + %7896 = torch.aten.lt.Scalar %7883, %float2.048000e03_9714 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7897 = torch.aten.bitwise_not %7896 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_9715 = torch.constant.float 8.192000e+03 + %7898 = torch.aten.gt.Scalar %7883, %float8.192000e03_9715 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7899 = torch.aten.bitwise_not %7898 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7900 = torch.aten.mul.Tensor %7897, %7899 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7901 = torch.aten.where.self %7900, %7895, %7886 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7902 = torch.prim.ListConstruct %7901, %7901 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_9716 = torch.constant.int -1 + %7903 = torch.aten.cat %7902, %int-1_9716 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_9717 = torch.constant.int 6 + %7904 = torch.prims.convert_element_type %7903, %int6_9717 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_9718 = torch.constant.int 1 + %7905 = torch.aten.unsqueeze %7875, %int1_9718 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_9719 = torch.constant.int 6 + %7906 = torch.prims.convert_element_type %7905, %int6_9719 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_9720 = torch.constant.int 0 + %7907 = torch.aten.unsqueeze %7904, %int0_9720 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_9721 = torch.constant.int 6 + %7908 = torch.prims.convert_element_type %7907, %int6_9721 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %7909 = torch.aten.mul.Tensor %7906, %7908 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %7910 = torch.aten.cos %7909 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_9722 = torch.constant.int 5 + %7911 = torch.prims.convert_element_type %7910, %int5_9722 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %7912 = torch.aten.sin %7909 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_9723 = torch.constant.int 5 + %7913 = torch.prims.convert_element_type %7912, %int5_9723 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_9724 = torch.constant.int 0 + %int0_9725 = torch.constant.int 0 + %int1_9726 = torch.constant.int 1 + %7914 = torch.aten.slice.Tensor %7911, %int0_9724, %int0_9725, %298, %int1_9726 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7914, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_9727 = torch.constant.int 1 + %int0_9728 = torch.constant.int 0 + 
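+ // %7875..%7913 above are another inline rebuild of the same scaled rotary table (the
+ // exporter appears to re-emit it at every use rather than hoisting it); the surrounding
+ // slices again trim the cos/sin tables to the current sequence length before broadcasting
+ // them for the 32 query heads of %7870.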
%int9223372036854775807_9729 = torch.constant.int 9223372036854775807 + %int1_9730 = torch.constant.int 1 + %7915 = torch.aten.slice.Tensor %7914, %int1_9727, %int0_9728, %int9223372036854775807_9729, %int1_9730 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7915, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_9731 = torch.constant.int 0 + %int0_9732 = torch.constant.int 0 + %int1_9733 = torch.constant.int 1 + %7916 = torch.aten.slice.Tensor %7913, %int0_9731, %int0_9732, %298, %int1_9733 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7916, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_9734 = torch.constant.int 1 + %int0_9735 = torch.constant.int 0 + %int9223372036854775807_9736 = torch.constant.int 9223372036854775807 + %int1_9737 = torch.constant.int 1 + %7917 = torch.aten.slice.Tensor %7916, %int1_9734, %int0_9735, %int9223372036854775807_9736, %int1_9737 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7917, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_9738 = torch.constant.int 0 + %7918 = torch.aten.unsqueeze %7915, %int0_9738 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7918, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_9739 = torch.constant.int 1 + %int0_9740 = torch.constant.int 0 + %int9223372036854775807_9741 = torch.constant.int 9223372036854775807 + %int1_9742 = torch.constant.int 1 + %7919 = torch.aten.slice.Tensor %7918, %int1_9739, %int0_9740, %int9223372036854775807_9741, %int1_9742 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7919, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_9743 = torch.constant.int 2 + %7920 = torch.aten.unsqueeze %7919, %int2_9743 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7920, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_9744 = torch.constant.int 3 + %int0_9745 = torch.constant.int 0 + %int9223372036854775807_9746 = torch.constant.int 9223372036854775807 + %int1_9747 = torch.constant.int 1 + %7921 = torch.aten.slice.Tensor %7920, %int3_9744, %int0_9745, %int9223372036854775807_9746, %int1_9747 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7921, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_9748 = torch.constant.int 4 + %int1_9749 = torch.constant.int 1 + %int1_9750 = torch.constant.int 1 + %int1_9751 = torch.constant.int 1 + %7922 = torch.prim.ListConstruct %int4_9748, %int1_9749, %int1_9750, %int1_9751 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7923 = torch.aten.repeat %7921, %7922 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7923, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_9752 = torch.constant.int 0 + %7924 = torch.aten.unsqueeze %7917, 
%int0_9752 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7924, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_9753 = torch.constant.int 1 + %int0_9754 = torch.constant.int 0 + %int9223372036854775807_9755 = torch.constant.int 9223372036854775807 + %int1_9756 = torch.constant.int 1 + %7925 = torch.aten.slice.Tensor %7924, %int1_9753, %int0_9754, %int9223372036854775807_9755, %int1_9756 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7925, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_9757 = torch.constant.int 2 + %7926 = torch.aten.unsqueeze %7925, %int2_9757 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7926, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_9758 = torch.constant.int 3 + %int0_9759 = torch.constant.int 0 + %int9223372036854775807_9760 = torch.constant.int 9223372036854775807 + %int1_9761 = torch.constant.int 1 + %7927 = torch.aten.slice.Tensor %7926, %int3_9758, %int0_9759, %int9223372036854775807_9760, %int1_9761 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7927, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_9762 = torch.constant.int 4 + %int1_9763 = torch.constant.int 1 + %int1_9764 = torch.constant.int 1 + %int1_9765 = torch.constant.int 1 + %7928 = torch.prim.ListConstruct %int4_9762, %int1_9763, %int1_9764, %int1_9765 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7929 = torch.aten.repeat %7927, %7928 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7929, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %7930 = torch.aten.mul.Tensor %7870, %7923 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7930, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_9766 = torch.constant.int 3 + %int0_9767 = torch.constant.int 0 + %int64_9768 = torch.constant.int 64 + %int1_9769 = torch.constant.int 1 + %7931 = torch.aten.slice.Tensor %7870, %int3_9766, %int0_9767, %int64_9768, %int1_9769 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %7931, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_9770 = torch.constant.int 3 + %int64_9771 = torch.constant.int 64 + %int9223372036854775807_9772 = torch.constant.int 9223372036854775807 + %int1_9773 = torch.constant.int 1 + %7932 = torch.aten.slice.Tensor %7870, %int3_9770, %int64_9771, %int9223372036854775807_9772, %int1_9773 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %7932, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %7933 = torch.aten.neg %7932 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %7933, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + 
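The slice/neg ops above, together with the cat, mul and add that open the next hunk, are the usual "rotate half" application of rotary position embeddings to the 32 query heads (head dim 128, split at 64). A minimal PyTorch sketch of the same arithmetic, with illustrative tensor names that do not come from the IR:

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Split the 128-wide head dim at 64 and swap the halves, negating the
        # second one, exactly as the slice/neg/cat sequence does.
        x1, x2 = x[..., :64], x[..., 64:]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # q: [4, seq, 32, 128]; cos/sin broadcast as [4, seq, 1, 128].
        return q * cos + rotate_half(q) * sin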
%7934 = torch.prim.ListConstruct %7933, %7931 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_9774 = torch.constant.int -1 + %7935 = torch.aten.cat %7934, %int-1_9774 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7935, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %7936 = torch.aten.mul.Tensor %7935, %7929 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7936, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_9775 = torch.constant.int 1 + %7937 = torch.aten.add.Tensor %7930, %7936, %int1_9775 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7937, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_9776 = torch.constant.int 131072 + %none_9777 = torch.constant.none + %none_9778 = torch.constant.none + %cpu_9779 = torch.constant.device "cpu" + %false_9780 = torch.constant.bool false + %7938 = torch.aten.arange %int131072_9776, %none_9777, %none_9778, %cpu_9779, %false_9780 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_9781 = torch.constant.int 0 + %int128_9782 = torch.constant.int 128 + %int2_9783 = torch.constant.int 2 + %int4_9784 = torch.constant.int 4 + %none_9785 = torch.constant.none + %cpu_9786 = torch.constant.device "cpu" + %false_9787 = torch.constant.bool false + %7939 = torch.aten.arange.start_step %int0_9781, %int128_9782, %int2_9783, %int4_9784, %none_9785, %cpu_9786, %false_9787 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_9788 = torch.constant.int 6 + %7940 = torch.prims.convert_element_type %7939, %int6_9788 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_9789 = torch.constant.int 128 + %7941 = torch.aten.div.Scalar %7940, %int128_9789 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_9790 = torch.constant.float 5.000000e+05 + %7942 = torch.aten.pow.Scalar %float5.000000e05_9790, %7941 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7943 = torch.aten.reciprocal %7942 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_9791 = torch.constant.float 1.000000e+00 + %7944 = torch.aten.mul.Scalar %7943, %float1.000000e00_9791 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %7945 = torch.aten.reciprocal %7944 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_9792 = torch.constant.float 6.2831853071795862 + %7946 = torch.aten.mul.Scalar %7945, %float6.283190e00_9792 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_9793 = torch.constant.float 8.192000e+03 + %7947 = torch.aten.gt.Scalar %7946, %float8.192000e03_9793 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_9794 = torch.constant.int 8 + %7948 = torch.aten.div.Scalar %7944, %int8_9794 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7949 = torch.aten.where.self %7947, %7948, %7944 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7950 = torch.aten.reciprocal %7946 : 
!torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_9795 = torch.constant.int 8192 + %7951 = torch.aten.mul.Scalar %7950, %int8192_9795 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_9796 = torch.constant.int 1 + %int1_9797 = torch.constant.int 1 + %7952 = torch.aten.sub.Scalar %7951, %int1_9796, %int1_9797 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_9798 = torch.constant.int 3 + %7953 = torch.aten.div.Scalar %7952, %int3_9798 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_9799 = torch.constant.int 1 + %int1_9800 = torch.constant.int 1 + %7954 = torch.aten.rsub.Scalar %7953, %int1_9799, %int1_9800 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %7955 = torch.aten.mul.Tensor %7954, %7949 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_9801 = torch.constant.int 8 + %7956 = torch.aten.div.Scalar %7955, %int8_9801 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %7957 = torch.aten.mul.Tensor %7953, %7949 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_9802 = torch.constant.int 1 + %7958 = torch.aten.add.Tensor %7956, %7957, %int1_9802 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_9803 = torch.constant.float 2.048000e+03 + %7959 = torch.aten.lt.Scalar %7946, %float2.048000e03_9803 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7960 = torch.aten.bitwise_not %7959 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_9804 = torch.constant.float 8.192000e+03 + %7961 = torch.aten.gt.Scalar %7946, %float8.192000e03_9804 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %7962 = torch.aten.bitwise_not %7961 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7963 = torch.aten.mul.Tensor %7960, %7962 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %7964 = torch.aten.where.self %7963, %7958, %7949 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %7965 = torch.prim.ListConstruct %7964, %7964 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_9805 = torch.constant.int -1 + %7966 = torch.aten.cat %7965, %int-1_9805 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_9806 = torch.constant.int 6 + %7967 = torch.prims.convert_element_type %7966, %int6_9806 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_9807 = torch.constant.int 1 + %7968 = torch.aten.unsqueeze %7938, %int1_9807 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_9808 = torch.constant.int 6 + %7969 = torch.prims.convert_element_type %7968, %int6_9808 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_9809 = torch.constant.int 0 + %7970 = torch.aten.unsqueeze %7967, %int0_9809 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_9810 = torch.constant.int 6 + %7971 = torch.prims.convert_element_type %7970, %int6_9810 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %7972 = torch.aten.mul.Tensor %7969, %7971 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %7973 = torch.aten.cos %7972 : !torch.vtensor<[131072,128],f32> -> 
!torch.vtensor<[131072,128],f32> + %int5_9811 = torch.constant.int 5 + %7974 = torch.prims.convert_element_type %7973, %int5_9811 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %7975 = torch.aten.sin %7972 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_9812 = torch.constant.int 5 + %7976 = torch.prims.convert_element_type %7975, %int5_9812 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_9813 = torch.constant.int 0 + %int0_9814 = torch.constant.int 0 + %int1_9815 = torch.constant.int 1 + %7977 = torch.aten.slice.Tensor %7974, %int0_9813, %int0_9814, %298, %int1_9815 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7977, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_9816 = torch.constant.int 1 + %int0_9817 = torch.constant.int 0 + %int9223372036854775807_9818 = torch.constant.int 9223372036854775807 + %int1_9819 = torch.constant.int 1 + %7978 = torch.aten.slice.Tensor %7977, %int1_9816, %int0_9817, %int9223372036854775807_9818, %int1_9819 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7978, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_9820 = torch.constant.int 0 + %int0_9821 = torch.constant.int 0 + %int1_9822 = torch.constant.int 1 + %7979 = torch.aten.slice.Tensor %7976, %int0_9820, %int0_9821, %298, %int1_9822 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7979, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_9823 = torch.constant.int 1 + %int0_9824 = torch.constant.int 0 + %int9223372036854775807_9825 = torch.constant.int 9223372036854775807 + %int1_9826 = torch.constant.int 1 + %7980 = torch.aten.slice.Tensor %7979, %int1_9823, %int0_9824, %int9223372036854775807_9825, %int1_9826 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7980, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_9827 = torch.constant.int 0 + %7981 = torch.aten.unsqueeze %7978, %int0_9827 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7981, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_9828 = torch.constant.int 1 + %int0_9829 = torch.constant.int 0 + %int9223372036854775807_9830 = torch.constant.int 9223372036854775807 + %int1_9831 = torch.constant.int 1 + %7982 = torch.aten.slice.Tensor %7981, %int1_9828, %int0_9829, %int9223372036854775807_9830, %int1_9831 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7982, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_9832 = torch.constant.int 2 + %7983 = torch.aten.unsqueeze %7982, %int2_9832 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7983, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_9833 = torch.constant.int 3 + %int0_9834 = torch.constant.int 0 + %int9223372036854775807_9835 = torch.constant.int 9223372036854775807 + 
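The arange/pow/where chain that ends above rebuilds the scaled RoPE frequency table before it is sliced for the key rotation. Reading the constants off the IR: theta = 5e5, head dim 128, 131072 positions, scale factor 8, and wavelength cutoffs at 2048 and 8192, i.e. the Llama-3-style smooth interpolation between the scaled and unscaled regimes. A PyTorch sketch under those constants:

    import math
    import torch

    def scaled_rope_tables(dim=128, max_pos=131072, theta=5.0e5,
                           factor=8.0, old_ctx=8192.0):
        inv_freq = 1.0 / theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
        wavelen = 2.0 * math.pi / inv_freq
        # Low-frequency components (wavelength > 8192) are slowed by `factor`.
        freqs = torch.where(wavelen > old_ctx, inv_freq / factor, inv_freq)
        # Medium band (2048 <= wavelength <= 8192): blend the two regimes.
        smooth = (old_ctx / wavelen - 1.0) / 3.0
        blended = (1.0 - smooth) * freqs / factor + smooth * freqs
        medium = (wavelen >= 2048.0) & (wavelen <= old_ctx)
        freqs = torch.where(medium, blended, freqs)
        angles = torch.arange(max_pos, dtype=torch.float32)[:, None] \
                 * torch.cat((freqs, freqs))[None, :]   # [131072, 128]
        return angles.cos().half(), angles.sin().half()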
%int1_9836 = torch.constant.int 1 + %7984 = torch.aten.slice.Tensor %7983, %int3_9833, %int0_9834, %int9223372036854775807_9835, %int1_9836 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7984, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_9837 = torch.constant.int 4 + %int1_9838 = torch.constant.int 1 + %int1_9839 = torch.constant.int 1 + %int1_9840 = torch.constant.int 1 + %7985 = torch.prim.ListConstruct %int4_9837, %int1_9838, %int1_9839, %int1_9840 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7986 = torch.aten.repeat %7984, %7985 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7986, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_9841 = torch.constant.int 0 + %7987 = torch.aten.unsqueeze %7980, %int0_9841 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7987, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_9842 = torch.constant.int 1 + %int0_9843 = torch.constant.int 0 + %int9223372036854775807_9844 = torch.constant.int 9223372036854775807 + %int1_9845 = torch.constant.int 1 + %7988 = torch.aten.slice.Tensor %7987, %int1_9842, %int0_9843, %int9223372036854775807_9844, %int1_9845 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %7988, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_9846 = torch.constant.int 2 + %7989 = torch.aten.unsqueeze %7988, %int2_9846 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7989, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_9847 = torch.constant.int 3 + %int0_9848 = torch.constant.int 0 + %int9223372036854775807_9849 = torch.constant.int 9223372036854775807 + %int1_9850 = torch.constant.int 1 + %7990 = torch.aten.slice.Tensor %7989, %int3_9847, %int0_9848, %int9223372036854775807_9849, %int1_9850 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %7990, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_9851 = torch.constant.int 4 + %int1_9852 = torch.constant.int 1 + %int1_9853 = torch.constant.int 1 + %int1_9854 = torch.constant.int 1 + %7991 = torch.prim.ListConstruct %int4_9851, %int1_9852, %int1_9853, %int1_9854 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7992 = torch.aten.repeat %7990, %7991 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %7992, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %7993 = torch.aten.mul.Tensor %7872, %7986 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7993, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_9855 = torch.constant.int 3 + %int0_9856 = torch.constant.int 0 + %int64_9857 = torch.constant.int 64 + %int1_9858 = torch.constant.int 1 + %7994 = torch.aten.slice.Tensor %7872, %int3_9855, %int0_9856, %int64_9857, %int1_9858 : 
!torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %7994, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_9859 = torch.constant.int 3 + %int64_9860 = torch.constant.int 64 + %int9223372036854775807_9861 = torch.constant.int 9223372036854775807 + %int1_9862 = torch.constant.int 1 + %7995 = torch.aten.slice.Tensor %7872, %int3_9859, %int64_9860, %int9223372036854775807_9861, %int1_9862 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %7995, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %7996 = torch.aten.neg %7995 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %7996, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %7997 = torch.prim.ListConstruct %7996, %7994 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_9863 = torch.constant.int -1 + %7998 = torch.aten.cat %7997, %int-1_9863 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7998, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %7999 = torch.aten.mul.Tensor %7998, %7992 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7999, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_9864 = torch.constant.int 1 + %8000 = torch.aten.add.Tensor %7993, %7999, %int1_9864 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8000, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_9865 = torch.constant.int 32 + %8001 = torch.aten.mul.Scalar %arg2, %int32_9865 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8001, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int28 = torch.constant.int 28 + %int1_9866 = torch.constant.int 1 + %8002 = torch.aten.add.Scalar %8001, %int28, %int1_9866 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8002, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_9867 = torch.constant.int 2 + %8003 = torch.aten.mul.Scalar %8002, %int2_9867 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8003, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_9868 = torch.constant.int 0 + %int1_9869 = torch.constant.int 1 + %8004 = torch.aten.add.Scalar %8003, %int0_9868, %int1_9869 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8004, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %8005 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %8006 = torch.aten.view %8004, %8005 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %8006, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_9870 = torch.constant.int 4 + %int32_9871 = torch.constant.int 32 + %int8_9872 = torch.constant.int 8 + %int128_9873 = torch.constant.int 128 + %8007 = 
torch.prim.ListConstruct %int4_9870, %296, %int32_9871, %int8_9872, %int128_9873 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8008 = torch.aten.view %8000, %8007 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %8008, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_9874 = torch.constant.int 32 + %int8_9875 = torch.constant.int 8 + %int128_9876 = torch.constant.int 128 + %8009 = torch.prim.ListConstruct %504, %int32_9874, %int8_9875, %int128_9876 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8010 = torch.aten.view %8008, %8009 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %8010, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_9877 = torch.constant.int 1 + %int2_9878 = torch.constant.int 2 + %8011 = torch.aten.transpose.int %8010, %int1_9877, %int2_9878 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8011, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_9879 = torch.constant.int 5 + %8012 = torch.prims.convert_element_type %8011, %int5_9879 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8012, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_9880 = torch.constant.int 32 + %int2_9881 = torch.constant.int 2 + %int8_9882 = torch.constant.int 8 + %int32_9883 = torch.constant.int 32 + %int128_9884 = torch.constant.int 128 + %8013 = torch.prim.ListConstruct %297, %int32_9880, %int2_9881, %int8_9882, %int32_9883, %int128_9884 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8014 = torch.aten.view %7776, %8013 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8014, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_9885 = torch.constant.int 8 + %int32_9886 = torch.constant.int 32 + %int128_9887 = torch.constant.int 128 + %8015 = torch.prim.ListConstruct %497, %int8_9885, %int32_9886, %int128_9887 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8016 = torch.aten.view %8014, %8015 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8016, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %8017 = torch.prim.ListConstruct %8006 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_9888 = torch.constant.bool false + %8018 = torch.aten.index_put %8016, %8017, %8012, %false_9888 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8018, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_9889 = torch.constant.int 32 + %int2_9890 = torch.constant.int 2 + %int8_9891 = torch.constant.int 8 + %int32_9892 = torch.constant.int 32 + %int128_9893 = torch.constant.int 128 + %8019 = torch.prim.ListConstruct %297, %int32_9889, %int2_9890, %int8_9891, %int32_9892, %int128_9893 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8020 = 
torch.aten.view %8018, %8019 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8020, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_9894 = torch.constant.int 2097152 + %8021 = torch.prim.ListConstruct %297, %int2097152_9894 : (!torch.int, !torch.int) -> !torch.list + %8022 = torch.aten.view %8020, %8021 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %8022, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_9895 = torch.constant.int 32 + %int2_9896 = torch.constant.int 2 + %int8_9897 = torch.constant.int 8 + %int32_9898 = torch.constant.int 32 + %int128_9899 = torch.constant.int 128 + %8023 = torch.prim.ListConstruct %297, %int32_9895, %int2_9896, %int8_9897, %int32_9898, %int128_9899 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8024 = torch.aten.view %8022, %8023 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8024, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_9900 = torch.constant.int 8 + %int32_9901 = torch.constant.int 32 + %int128_9902 = torch.constant.int 128 + %8025 = torch.prim.ListConstruct %497, %int8_9900, %int32_9901, %int128_9902 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8026 = torch.aten.view %8024, %8025 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8026, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_9903 = torch.constant.int 32 + %8027 = torch.aten.mul.Scalar %arg2, %int32_9903 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8027, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int28_9904 = torch.constant.int 28 + %int1_9905 = torch.constant.int 1 + %8028 = torch.aten.add.Scalar %8027, %int28_9904, %int1_9905 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8028, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_9906 = torch.constant.int 2 + %8029 = torch.aten.mul.Scalar %8028, %int2_9906 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8029, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_9907 = torch.constant.int 1 + %int1_9908 = torch.constant.int 1 + %8030 = torch.aten.add.Scalar %8029, %int1_9907, %int1_9908 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8030, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %8031 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %8032 = torch.aten.view %8030, %8031 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %8032, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_9909 = torch.constant.int 4 + %int32_9910 = torch.constant.int 32 + %int8_9911 = torch.constant.int 8 + %int128_9912 = torch.constant.int 128 + %8033 = torch.prim.ListConstruct %int4_9909, %296, %int32_9910, %int8_9911, %int128_9912 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8034 = 
torch.aten.view %7874, %8033 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %8034, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_9913 = torch.constant.int 32 + %int8_9914 = torch.constant.int 8 + %int128_9915 = torch.constant.int 128 + %8035 = torch.prim.ListConstruct %504, %int32_9913, %int8_9914, %int128_9915 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8036 = torch.aten.view %8034, %8035 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %8036, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_9916 = torch.constant.int 1 + %int2_9917 = torch.constant.int 2 + %8037 = torch.aten.transpose.int %8036, %int1_9916, %int2_9917 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8037, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_9918 = torch.constant.int 5 + %8038 = torch.prims.convert_element_type %8037, %int5_9918 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8038, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %8039 = torch.prim.ListConstruct %8032 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_9919 = torch.constant.bool false + %8040 = torch.aten.index_put %8026, %8039, %8038, %false_9919 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8040, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_9920 = torch.constant.int 32 + %int2_9921 = torch.constant.int 2 + %int8_9922 = torch.constant.int 8 + %int32_9923 = torch.constant.int 32 + %int128_9924 = torch.constant.int 128 + %8041 = torch.prim.ListConstruct %297, %int32_9920, %int2_9921, %int8_9922, %int32_9923, %int128_9924 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8042 = torch.aten.view %8040, %8041 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8042, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_9925 = torch.constant.int 2097152 + %8043 = torch.prim.ListConstruct %297, %int2097152_9925 : (!torch.int, !torch.int) -> !torch.list + %8044 = torch.aten.view %8042, %8043 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %8044, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_9926 = torch.constant.int -2 + %8045 = torch.aten.unsqueeze %8000, %int-2_9926 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %8045, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_9927 = torch.constant.int 4 + %int8_9928 = torch.constant.int 8 + %int4_9929 = torch.constant.int 4 + %int128_9930 = torch.constant.int 128 + %8046 = torch.prim.ListConstruct %int4_9927, %298, %int8_9928, %int4_9929, %int128_9930 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_9931 = torch.constant.bool false + %8047 = torch.aten.expand 
%8045, %8046, %false_9931 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8047, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_9932 = torch.constant.int 0 + %8048 = torch.aten.clone %8047, %int0_9932 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8048, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_9933 = torch.constant.int 4 + %int32_9934 = torch.constant.int 32 + %int128_9935 = torch.constant.int 128 + %8049 = torch.prim.ListConstruct %int4_9933, %298, %int32_9934, %int128_9935 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8050 = torch.aten._unsafe_view %8048, %8049 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8050, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_9936 = torch.constant.int -2 + %8051 = torch.aten.unsqueeze %7874, %int-2_9936 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %8051, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_9937 = torch.constant.int 4 + %int8_9938 = torch.constant.int 8 + %int4_9939 = torch.constant.int 4 + %int128_9940 = torch.constant.int 128 + %8052 = torch.prim.ListConstruct %int4_9937, %298, %int8_9938, %int4_9939, %int128_9940 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_9941 = torch.constant.bool false + %8053 = torch.aten.expand %8051, %8052, %false_9941 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8053, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_9942 = torch.constant.int 0 + %8054 = torch.aten.clone %8053, %int0_9942 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8054, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_9943 = torch.constant.int 4 + %int32_9944 = torch.constant.int 32 + %int128_9945 = torch.constant.int 128 + %8055 = torch.prim.ListConstruct %int4_9943, %298, %int32_9944, %int128_9945 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8056 = torch.aten._unsafe_view %8054, %8055 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8056, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_9946 = torch.constant.int 1 + %int2_9947 = torch.constant.int 2 + %8057 = torch.aten.transpose.int %7937, %int1_9946, %int2_9947 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %8057, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_9948 = torch.constant.int 1 + %int2_9949 = torch.constant.int 2 + %8058 = torch.aten.transpose.int %8050, %int1_9948, %int2_9949 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %8058, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_9950 = torch.constant.int 1 + %int2_9951 = 
torch.constant.int 2 + %8059 = torch.aten.transpose.int %8056, %int1_9950, %int2_9951 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %8059, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_9952 = torch.constant.float 0.000000e+00 + %false_9953 = torch.constant.bool false + %none_9954 = torch.constant.none + %8060:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%8057, %8058, %8059, %float0.000000e00_9952, %false_9953, %327, %none_9954) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %8060#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_9955 = torch.constant.int 1 + %int2_9956 = torch.constant.int 2 + %8061 = torch.aten.transpose.int %8060#0, %int1_9955, %int2_9956 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8061, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_9957 = torch.constant.int 4 + %int4096_9958 = torch.constant.int 4096 + %8062 = torch.prim.ListConstruct %int4_9957, %298, %int4096_9958 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8063 = torch.aten.view %8061, %8062 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8063, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_9959 = torch.constant.int -2 + %int-1_9960 = torch.constant.int -1 + %8064 = torch.aten.transpose.int %258, %int-2_9959, %int-1_9960 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_9961 = torch.constant.int 5 + %8065 = torch.prims.convert_element_type %8064, %int5_9961 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_9962 = torch.constant.int 4096 + %8066 = torch.prim.ListConstruct %342, %int4096_9962 : (!torch.int, !torch.int) -> !torch.list + %8067 = torch.aten.view %8063, %8066 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8067, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8068 = torch.aten.mm %8067, %8065 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8068, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_9963 = torch.constant.int 4 + %int4096_9964 = torch.constant.int 4096 + %8069 = torch.prim.ListConstruct %int4_9963, %298, %int4096_9964 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8070 = torch.aten.view %8068, %8069 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8070, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_9965 = torch.constant.int 1 + %8071 = torch.aten.add.Tensor %7837, %8070, %int1_9965 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8071, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + 
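By the end of this hunk the rotated keys and the values have been scattered into the paged cache (the flat slot index is (block_id * 32 + 28) * 2 for keys and the same plus 1 for values, i.e. layer 28 of 32 in the [pages, 32, 2, 8, 32, 128] cache view; 32 * 2 * 8 * 32 * 128 = 2097152), the 8 KV heads have been broadcast to the 32 query heads, and flash attention plus the output projection and residual add have run. A sketch of the grouped-query expansion and attention call, assuming PyTorch's generic SDPA in place of the _scaled_dot_product_flash_attention_for_cpu operator used here:

    import torch
    import torch.nn.functional as F

    def repeat_kv(x: torch.Tensor, n_rep: int = 4) -> torch.Tensor:
        # [B, S, 8, 128] -> [B, S, 8, 4, 128] -> [B, S, 32, 128], matching the
        # unsqueeze/expand/clone/_unsafe_view sequence above.
        b, s, kv_heads, d = x.shape
        x = x[:, :, :, None, :].expand(b, s, kv_heads, n_rep, d)
        return x.reshape(b, s, kv_heads * n_rep, d)

    def attend(q, k, v, mask):
        # q: [4, S, 32, 128]; k, v: [4, S, 8, 128]; mask: [4, 1, S, S]
        q = q.transpose(1, 2)                    # [4, 32, S, 128]
        k = repeat_kv(k).transpose(1, 2)
        v = repeat_kv(v).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return out.transpose(1, 2).reshape(4, -1, 32 * 128)   # [4, S, 4096]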
%int6_9966 = torch.constant.int 6 + %8072 = torch.prims.convert_element_type %8071, %int6_9966 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8072, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_9967 = torch.constant.int 2 + %8073 = torch.aten.pow.Tensor_Scalar %8072, %int2_9967 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8073, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_9968 = torch.constant.int -1 + %8074 = torch.prim.ListConstruct %int-1_9968 : (!torch.int) -> !torch.list + %true_9969 = torch.constant.bool true + %none_9970 = torch.constant.none + %8075 = torch.aten.mean.dim %8073, %8074, %true_9969, %none_9970 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8075, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_9971 = torch.constant.float 9.9999997473787516E-6 + %int1_9972 = torch.constant.int 1 + %8076 = torch.aten.add.Scalar %8075, %float9.999990e-06_9971, %int1_9972 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8076, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8077 = torch.aten.rsqrt %8076 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8077, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8078 = torch.aten.mul.Tensor %8072, %8077 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8078, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_9973 = torch.constant.int 5 + %8079 = torch.prims.convert_element_type %8078, %int5_9973 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8079, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %8080 = torch.aten.mul.Tensor %259, %8079 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8080, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_9974 = torch.constant.int 5 + %8081 = torch.prims.convert_element_type %8080, %int5_9974 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8081, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_9975 = torch.constant.int -2 + %int-1_9976 = torch.constant.int -1 + %8082 = torch.aten.transpose.int %260, %int-2_9975, %int-1_9976 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_9977 = torch.constant.int 5 + %8083 = torch.prims.convert_element_type %8082, %int5_9977 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_9978 = torch.constant.int 4096 + %8084 = torch.prim.ListConstruct %342, %int4096_9978 : (!torch.int, !torch.int) -> !torch.list + %8085 = torch.aten.view %8081, %8084 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8085, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8086 = torch.aten.mm %8085, 
%8083 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %8086, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_9979 = torch.constant.int 4 + %int14336_9980 = torch.constant.int 14336 + %8087 = torch.prim.ListConstruct %int4_9979, %298, %int14336_9980 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8088 = torch.aten.view %8086, %8087 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8088, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %8089 = torch.aten.silu %8088 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8089, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_9981 = torch.constant.int -2 + %int-1_9982 = torch.constant.int -1 + %8090 = torch.aten.transpose.int %261, %int-2_9981, %int-1_9982 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_9983 = torch.constant.int 5 + %8091 = torch.prims.convert_element_type %8090, %int5_9983 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_9984 = torch.constant.int 4096 + %8092 = torch.prim.ListConstruct %342, %int4096_9984 : (!torch.int, !torch.int) -> !torch.list + %8093 = torch.aten.view %8081, %8092 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8093, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8094 = torch.aten.mm %8093, %8091 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %8094, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_9985 = torch.constant.int 4 + %int14336_9986 = torch.constant.int 14336 + %8095 = torch.prim.ListConstruct %int4_9985, %298, %int14336_9986 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8096 = torch.aten.view %8094, %8095 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8096, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %8097 = torch.aten.mul.Tensor %8089, %8096 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8097, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_9987 = torch.constant.int -2 + %int-1_9988 = torch.constant.int -1 + %8098 = torch.aten.transpose.int %262, %int-2_9987, %int-1_9988 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_9989 = torch.constant.int 5 + %8099 = torch.prims.convert_element_type %8098, %int5_9989 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_9990 = torch.constant.int 14336 + %8100 = torch.prim.ListConstruct %342, %int14336_9990 : (!torch.int, !torch.int) -> !torch.list + %8101 = torch.aten.view %8097, %8100 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %8101, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %8102 = torch.aten.mm %8101, %8099 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + 
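The hunk above is the post-attention RMSNorm (statistics in f32, eps ≈ 1e-5) followed by the SwiGLU feed-forward with hidden dim 14336: silu(x @ W_gate^T) * (x @ W_up^T), projected back through W_down and, at the start of the next hunk, added to the residual stream. Equivalent PyTorch, with hypothetical weight names:

    import torch
    import torch.nn.functional as F

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        x32 = x.float()                           # norm is computed in f32, as in the IR
        x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
        return (weight * x32.half()).half()       # f32 weight * f16 activation, cast back

    def swiglu_ffn(x, w_gate, w_up, w_down):
        # x: [4, S, 4096]; w_gate, w_up: [14336, 4096]; w_down: [4096, 14336]
        return (F.silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T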
torch.bind_symbolic_shape %8102, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_9991 = torch.constant.int 4 + %int4096_9992 = torch.constant.int 4096 + %8103 = torch.prim.ListConstruct %int4_9991, %298, %int4096_9992 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8104 = torch.aten.view %8102, %8103 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8104, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_9993 = torch.constant.int 1 + %8105 = torch.aten.add.Tensor %8071, %8104, %int1_9993 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8105, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_9994 = torch.constant.int 6 + %8106 = torch.prims.convert_element_type %8105, %int6_9994 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8106, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_9995 = torch.constant.int 2 + %8107 = torch.aten.pow.Tensor_Scalar %8106, %int2_9995 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8107, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_9996 = torch.constant.int -1 + %8108 = torch.prim.ListConstruct %int-1_9996 : (!torch.int) -> !torch.list + %true_9997 = torch.constant.bool true + %none_9998 = torch.constant.none + %8109 = torch.aten.mean.dim %8107, %8108, %true_9997, %none_9998 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8109, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_9999 = torch.constant.float 9.9999997473787516E-6 + %int1_10000 = torch.constant.int 1 + %8110 = torch.aten.add.Scalar %8109, %float9.999990e-06_9999, %int1_10000 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8110, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8111 = torch.aten.rsqrt %8110 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8111, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8112 = torch.aten.mul.Tensor %8106, %8111 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8112, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_10001 = torch.constant.int 5 + %8113 = torch.prims.convert_element_type %8112, %int5_10001 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8113, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %8114 = torch.aten.mul.Tensor %263, %8113 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8114, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_10002 = torch.constant.int 5 + %8115 = torch.prims.convert_element_type %8114, %int5_10002 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8115, [%294], affine_map<()[s0] 
-> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_10003 = torch.constant.int -2 + %int-1_10004 = torch.constant.int -1 + %8116 = torch.aten.transpose.int %264, %int-2_10003, %int-1_10004 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_10005 = torch.constant.int 5 + %8117 = torch.prims.convert_element_type %8116, %int5_10005 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_10006 = torch.constant.int 4096 + %8118 = torch.prim.ListConstruct %342, %int4096_10006 : (!torch.int, !torch.int) -> !torch.list + %8119 = torch.aten.view %8115, %8118 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8119, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8120 = torch.aten.mm %8119, %8117 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8120, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_10007 = torch.constant.int 4 + %int4096_10008 = torch.constant.int 4096 + %8121 = torch.prim.ListConstruct %int4_10007, %298, %int4096_10008 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8122 = torch.aten.view %8120, %8121 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8122, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_10009 = torch.constant.int -2 + %int-1_10010 = torch.constant.int -1 + %8123 = torch.aten.transpose.int %265, %int-2_10009, %int-1_10010 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_10011 = torch.constant.int 5 + %8124 = torch.prims.convert_element_type %8123, %int5_10011 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_10012 = torch.constant.int 4096 + %8125 = torch.prim.ListConstruct %342, %int4096_10012 : (!torch.int, !torch.int) -> !torch.list + %8126 = torch.aten.view %8115, %8125 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8126, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8127 = torch.aten.mm %8126, %8124 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %8127, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_10013 = torch.constant.int 4 + %int1024_10014 = torch.constant.int 1024 + %8128 = torch.prim.ListConstruct %int4_10013, %298, %int1024_10014 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8129 = torch.aten.view %8127, %8128 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %8129, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_10015 = torch.constant.int -2 + %int-1_10016 = torch.constant.int -1 + %8130 = torch.aten.transpose.int %266, %int-2_10015, %int-1_10016 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_10017 = torch.constant.int 5 + %8131 = torch.prims.convert_element_type %8130, %int5_10017 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_10018 = torch.constant.int 4096 + %8132 = torch.prim.ListConstruct %342, %int4096_10018 : (!torch.int, !torch.int) -> 
!torch.list + %8133 = torch.aten.view %8115, %8132 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8133, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8134 = torch.aten.mm %8133, %8131 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %8134, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_10019 = torch.constant.int 4 + %int1024_10020 = torch.constant.int 1024 + %8135 = torch.prim.ListConstruct %int4_10019, %298, %int1024_10020 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8136 = torch.aten.view %8134, %8135 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %8136, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_10021 = torch.constant.int 4 + %int32_10022 = torch.constant.int 32 + %int128_10023 = torch.constant.int 128 + %8137 = torch.prim.ListConstruct %int4_10021, %298, %int32_10022, %int128_10023 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8138 = torch.aten.view %8122, %8137 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8138, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_10024 = torch.constant.int 4 + %int8_10025 = torch.constant.int 8 + %int128_10026 = torch.constant.int 128 + %8139 = torch.prim.ListConstruct %int4_10024, %298, %int8_10025, %int128_10026 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8140 = torch.aten.view %8129, %8139 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8140, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_10027 = torch.constant.int 4 + %int8_10028 = torch.constant.int 8 + %int128_10029 = torch.constant.int 128 + %8141 = torch.prim.ListConstruct %int4_10027, %298, %int8_10028, %int128_10029 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8142 = torch.aten.view %8136, %8141 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8142, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_10030 = torch.constant.int 131072 + %none_10031 = torch.constant.none + %none_10032 = torch.constant.none + %cpu_10033 = torch.constant.device "cpu" + %false_10034 = torch.constant.bool false + %8143 = torch.aten.arange %int131072_10030, %none_10031, %none_10032, %cpu_10033, %false_10034 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_10035 = torch.constant.int 0 + %int128_10036 = torch.constant.int 128 + %int2_10037 = torch.constant.int 2 + %int4_10038 = torch.constant.int 4 + %none_10039 = torch.constant.none + %cpu_10040 = torch.constant.device "cpu" + %false_10041 = torch.constant.bool false + %8144 = torch.aten.arange.start_step %int0_10035, %int128_10036, %int2_10037, %int4_10038, %none_10039, %cpu_10040, %false_10041 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_10042 = torch.constant.int 6 + %8145 = torch.prims.convert_element_type %8144, %int6_10042 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_10043 = 
torch.constant.int 128 + %8146 = torch.aten.div.Scalar %8145, %int128_10043 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_10044 = torch.constant.float 5.000000e+05 + %8147 = torch.aten.pow.Scalar %float5.000000e05_10044, %8146 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8148 = torch.aten.reciprocal %8147 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_10045 = torch.constant.float 1.000000e+00 + %8149 = torch.aten.mul.Scalar %8148, %float1.000000e00_10045 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %8150 = torch.aten.reciprocal %8149 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_10046 = torch.constant.float 6.2831853071795862 + %8151 = torch.aten.mul.Scalar %8150, %float6.283190e00_10046 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_10047 = torch.constant.float 8.192000e+03 + %8152 = torch.aten.gt.Scalar %8151, %float8.192000e03_10047 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_10048 = torch.constant.int 8 + %8153 = torch.aten.div.Scalar %8149, %int8_10048 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %8154 = torch.aten.where.self %8152, %8153, %8149 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8155 = torch.aten.reciprocal %8151 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_10049 = torch.constant.int 8192 + %8156 = torch.aten.mul.Scalar %8155, %int8192_10049 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_10050 = torch.constant.int 1 + %int1_10051 = torch.constant.int 1 + %8157 = torch.aten.sub.Scalar %8156, %int1_10050, %int1_10051 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_10052 = torch.constant.int 3 + %8158 = torch.aten.div.Scalar %8157, %int3_10052 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_10053 = torch.constant.int 1 + %int1_10054 = torch.constant.int 1 + %8159 = torch.aten.rsub.Scalar %8158, %int1_10053, %int1_10054 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %8160 = torch.aten.mul.Tensor %8159, %8154 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_10055 = torch.constant.int 8 + %8161 = torch.aten.div.Scalar %8160, %int8_10055 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %8162 = torch.aten.mul.Tensor %8158, %8154 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_10056 = torch.constant.int 1 + %8163 = torch.aten.add.Tensor %8161, %8162, %int1_10056 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_10057 = torch.constant.float 2.048000e+03 + %8164 = torch.aten.lt.Scalar %8151, %float2.048000e03_10057 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %8165 = torch.aten.bitwise_not %8164 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_10058 = torch.constant.float 8.192000e+03 + %8166 = torch.aten.gt.Scalar %8151, %float8.192000e03_10058 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %8167 = torch.aten.bitwise_not %8166 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %8168 = torch.aten.mul.Tensor %8165, %8167 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> 
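From here the same pattern starts over for the next block: the three matmuls further up are its Q, K and V projections (4096 -> 4096 for queries, 4096 -> 1024 for the 8 KV heads), and the arange/pow chain now in progress rebuilds the identical scaled frequency table once more. In PyTorch terms, with illustrative names:

    import torch

    def qkv_proj(x, w_q, w_k, w_v):
        # x: [4, S, 4096]; w_q: [4096, 4096]; w_k, w_v: [1024, 4096]
        q = (x @ w_q.T).view(4, -1, 32, 128)   # 32 query heads
        k = (x @ w_k.T).view(4, -1, 8, 128)    # 8 KV heads
        v = (x @ w_v.T).view(4, -1, 8, 128)
        return q, k, v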
+ %8169 = torch.aten.where.self %8168, %8163, %8154 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8170 = torch.prim.ListConstruct %8169, %8169 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_10059 = torch.constant.int -1 + %8171 = torch.aten.cat %8170, %int-1_10059 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_10060 = torch.constant.int 6 + %8172 = torch.prims.convert_element_type %8171, %int6_10060 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_10061 = torch.constant.int 1 + %8173 = torch.aten.unsqueeze %8143, %int1_10061 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_10062 = torch.constant.int 6 + %8174 = torch.prims.convert_element_type %8173, %int6_10062 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_10063 = torch.constant.int 0 + %8175 = torch.aten.unsqueeze %8172, %int0_10063 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_10064 = torch.constant.int 6 + %8176 = torch.prims.convert_element_type %8175, %int6_10064 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %8177 = torch.aten.mul.Tensor %8174, %8176 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %8178 = torch.aten.cos %8177 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_10065 = torch.constant.int 5 + %8179 = torch.prims.convert_element_type %8178, %int5_10065 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %8180 = torch.aten.sin %8177 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_10066 = torch.constant.int 5 + %8181 = torch.prims.convert_element_type %8180, %int5_10066 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_10067 = torch.constant.int 0 + %int0_10068 = torch.constant.int 0 + %int1_10069 = torch.constant.int 1 + %8182 = torch.aten.slice.Tensor %8179, %int0_10067, %int0_10068, %298, %int1_10069 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8182, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_10070 = torch.constant.int 1 + %int0_10071 = torch.constant.int 0 + %int9223372036854775807_10072 = torch.constant.int 9223372036854775807 + %int1_10073 = torch.constant.int 1 + %8183 = torch.aten.slice.Tensor %8182, %int1_10070, %int0_10071, %int9223372036854775807_10072, %int1_10073 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8183, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_10074 = torch.constant.int 0 + %int0_10075 = torch.constant.int 0 + %int1_10076 = torch.constant.int 1 + %8184 = torch.aten.slice.Tensor %8181, %int0_10074, %int0_10075, %298, %int1_10076 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8184, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_10077 = torch.constant.int 1 + %int0_10078 = torch.constant.int 0 + %int9223372036854775807_10079 = torch.constant.int 9223372036854775807 + %int1_10080 = torch.constant.int 1 + %8185 = torch.aten.slice.Tensor %8184, 
%int1_10077, %int0_10078, %int9223372036854775807_10079, %int1_10080 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8185, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_10081 = torch.constant.int 0 + %8186 = torch.aten.unsqueeze %8183, %int0_10081 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8186, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_10082 = torch.constant.int 1 + %int0_10083 = torch.constant.int 0 + %int9223372036854775807_10084 = torch.constant.int 9223372036854775807 + %int1_10085 = torch.constant.int 1 + %8187 = torch.aten.slice.Tensor %8186, %int1_10082, %int0_10083, %int9223372036854775807_10084, %int1_10085 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8187, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_10086 = torch.constant.int 2 + %8188 = torch.aten.unsqueeze %8187, %int2_10086 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8188, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_10087 = torch.constant.int 3 + %int0_10088 = torch.constant.int 0 + %int9223372036854775807_10089 = torch.constant.int 9223372036854775807 + %int1_10090 = torch.constant.int 1 + %8189 = torch.aten.slice.Tensor %8188, %int3_10087, %int0_10088, %int9223372036854775807_10089, %int1_10090 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8189, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_10091 = torch.constant.int 4 + %int1_10092 = torch.constant.int 1 + %int1_10093 = torch.constant.int 1 + %int1_10094 = torch.constant.int 1 + %8190 = torch.prim.ListConstruct %int4_10091, %int1_10092, %int1_10093, %int1_10094 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8191 = torch.aten.repeat %8189, %8190 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %8191, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_10095 = torch.constant.int 0 + %8192 = torch.aten.unsqueeze %8185, %int0_10095 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8192, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_10096 = torch.constant.int 1 + %int0_10097 = torch.constant.int 0 + %int9223372036854775807_10098 = torch.constant.int 9223372036854775807 + %int1_10099 = torch.constant.int 1 + %8193 = torch.aten.slice.Tensor %8192, %int1_10096, %int0_10097, %int9223372036854775807_10098, %int1_10099 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8193, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_10100 = torch.constant.int 2 + %8194 = torch.aten.unsqueeze %8193, %int2_10100 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8194, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_10101 = 
torch.constant.int 3 + %int0_10102 = torch.constant.int 0 + %int9223372036854775807_10103 = torch.constant.int 9223372036854775807 + %int1_10104 = torch.constant.int 1 + %8195 = torch.aten.slice.Tensor %8194, %int3_10101, %int0_10102, %int9223372036854775807_10103, %int1_10104 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8195, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_10105 = torch.constant.int 4 + %int1_10106 = torch.constant.int 1 + %int1_10107 = torch.constant.int 1 + %int1_10108 = torch.constant.int 1 + %8196 = torch.prim.ListConstruct %int4_10105, %int1_10106, %int1_10107, %int1_10108 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8197 = torch.aten.repeat %8195, %8196 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %8197, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %8198 = torch.aten.mul.Tensor %8138, %8191 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8198, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_10109 = torch.constant.int 3 + %int0_10110 = torch.constant.int 0 + %int64_10111 = torch.constant.int 64 + %int1_10112 = torch.constant.int 1 + %8199 = torch.aten.slice.Tensor %8138, %int3_10109, %int0_10110, %int64_10111, %int1_10112 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %8199, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_10113 = torch.constant.int 3 + %int64_10114 = torch.constant.int 64 + %int9223372036854775807_10115 = torch.constant.int 9223372036854775807 + %int1_10116 = torch.constant.int 1 + %8200 = torch.aten.slice.Tensor %8138, %int3_10113, %int64_10114, %int9223372036854775807_10115, %int1_10116 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %8200, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %8201 = torch.aten.neg %8200 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %8201, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %8202 = torch.prim.ListConstruct %8201, %8199 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_10117 = torch.constant.int -1 + %8203 = torch.aten.cat %8202, %int-1_10117 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8203, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %8204 = torch.aten.mul.Tensor %8203, %8197 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8204, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_10118 = torch.constant.int 1 + %8205 = torch.aten.add.Tensor %8198, %8204, %int1_10118 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8205, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : 
!torch.vtensor<[4,?,32,128],f16> + %int131072_10119 = torch.constant.int 131072 + %none_10120 = torch.constant.none + %none_10121 = torch.constant.none + %cpu_10122 = torch.constant.device "cpu" + %false_10123 = torch.constant.bool false + %8206 = torch.aten.arange %int131072_10119, %none_10120, %none_10121, %cpu_10122, %false_10123 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_10124 = torch.constant.int 0 + %int128_10125 = torch.constant.int 128 + %int2_10126 = torch.constant.int 2 + %int4_10127 = torch.constant.int 4 + %none_10128 = torch.constant.none + %cpu_10129 = torch.constant.device "cpu" + %false_10130 = torch.constant.bool false + %8207 = torch.aten.arange.start_step %int0_10124, %int128_10125, %int2_10126, %int4_10127, %none_10128, %cpu_10129, %false_10130 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_10131 = torch.constant.int 6 + %8208 = torch.prims.convert_element_type %8207, %int6_10131 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_10132 = torch.constant.int 128 + %8209 = torch.aten.div.Scalar %8208, %int128_10132 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_10133 = torch.constant.float 5.000000e+05 + %8210 = torch.aten.pow.Scalar %float5.000000e05_10133, %8209 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8211 = torch.aten.reciprocal %8210 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_10134 = torch.constant.float 1.000000e+00 + %8212 = torch.aten.mul.Scalar %8211, %float1.000000e00_10134 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %8213 = torch.aten.reciprocal %8212 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_10135 = torch.constant.float 6.2831853071795862 + %8214 = torch.aten.mul.Scalar %8213, %float6.283190e00_10135 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_10136 = torch.constant.float 8.192000e+03 + %8215 = torch.aten.gt.Scalar %8214, %float8.192000e03_10136 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_10137 = torch.constant.int 8 + %8216 = torch.aten.div.Scalar %8212, %int8_10137 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %8217 = torch.aten.where.self %8215, %8216, %8212 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8218 = torch.aten.reciprocal %8214 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_10138 = torch.constant.int 8192 + %8219 = torch.aten.mul.Scalar %8218, %int8192_10138 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_10139 = torch.constant.int 1 + %int1_10140 = torch.constant.int 1 + %8220 = torch.aten.sub.Scalar %8219, %int1_10139, %int1_10140 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_10141 = torch.constant.int 3 + %8221 = torch.aten.div.Scalar %8220, %int3_10141 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_10142 = torch.constant.int 1 + %int1_10143 = torch.constant.int 1 + %8222 = torch.aten.rsub.Scalar %8221, %int1_10142, %int1_10143 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %8223 = torch.aten.mul.Tensor %8222, %8217 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_10144 = 
torch.constant.int 8 + %8224 = torch.aten.div.Scalar %8223, %int8_10144 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %8225 = torch.aten.mul.Tensor %8221, %8217 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_10145 = torch.constant.int 1 + %8226 = torch.aten.add.Tensor %8224, %8225, %int1_10145 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_10146 = torch.constant.float 2.048000e+03 + %8227 = torch.aten.lt.Scalar %8214, %float2.048000e03_10146 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %8228 = torch.aten.bitwise_not %8227 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_10147 = torch.constant.float 8.192000e+03 + %8229 = torch.aten.gt.Scalar %8214, %float8.192000e03_10147 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %8230 = torch.aten.bitwise_not %8229 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %8231 = torch.aten.mul.Tensor %8228, %8230 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %8232 = torch.aten.where.self %8231, %8226, %8217 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8233 = torch.prim.ListConstruct %8232, %8232 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_10148 = torch.constant.int -1 + %8234 = torch.aten.cat %8233, %int-1_10148 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_10149 = torch.constant.int 6 + %8235 = torch.prims.convert_element_type %8234, %int6_10149 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_10150 = torch.constant.int 1 + %8236 = torch.aten.unsqueeze %8206, %int1_10150 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_10151 = torch.constant.int 6 + %8237 = torch.prims.convert_element_type %8236, %int6_10151 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_10152 = torch.constant.int 0 + %8238 = torch.aten.unsqueeze %8235, %int0_10152 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_10153 = torch.constant.int 6 + %8239 = torch.prims.convert_element_type %8238, %int6_10153 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %8240 = torch.aten.mul.Tensor %8237, %8239 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %8241 = torch.aten.cos %8240 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_10154 = torch.constant.int 5 + %8242 = torch.prims.convert_element_type %8241, %int5_10154 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %8243 = torch.aten.sin %8240 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_10155 = torch.constant.int 5 + %8244 = torch.prims.convert_element_type %8243, %int5_10155 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_10156 = torch.constant.int 0 + %int0_10157 = torch.constant.int 0 + %int1_10158 = torch.constant.int 1 + %8245 = torch.aten.slice.Tensor %8242, %int0_10156, %int0_10157, %298, %int1_10158 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8245, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_10159 = 
torch.constant.int 1 + %int0_10160 = torch.constant.int 0 + %int9223372036854775807_10161 = torch.constant.int 9223372036854775807 + %int1_10162 = torch.constant.int 1 + %8246 = torch.aten.slice.Tensor %8245, %int1_10159, %int0_10160, %int9223372036854775807_10161, %int1_10162 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8246, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_10163 = torch.constant.int 0 + %int0_10164 = torch.constant.int 0 + %int1_10165 = torch.constant.int 1 + %8247 = torch.aten.slice.Tensor %8244, %int0_10163, %int0_10164, %298, %int1_10165 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8247, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_10166 = torch.constant.int 1 + %int0_10167 = torch.constant.int 0 + %int9223372036854775807_10168 = torch.constant.int 9223372036854775807 + %int1_10169 = torch.constant.int 1 + %8248 = torch.aten.slice.Tensor %8247, %int1_10166, %int0_10167, %int9223372036854775807_10168, %int1_10169 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8248, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_10170 = torch.constant.int 0 + %8249 = torch.aten.unsqueeze %8246, %int0_10170 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8249, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_10171 = torch.constant.int 1 + %int0_10172 = torch.constant.int 0 + %int9223372036854775807_10173 = torch.constant.int 9223372036854775807 + %int1_10174 = torch.constant.int 1 + %8250 = torch.aten.slice.Tensor %8249, %int1_10171, %int0_10172, %int9223372036854775807_10173, %int1_10174 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8250, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_10175 = torch.constant.int 2 + %8251 = torch.aten.unsqueeze %8250, %int2_10175 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8251, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_10176 = torch.constant.int 3 + %int0_10177 = torch.constant.int 0 + %int9223372036854775807_10178 = torch.constant.int 9223372036854775807 + %int1_10179 = torch.constant.int 1 + %8252 = torch.aten.slice.Tensor %8251, %int3_10176, %int0_10177, %int9223372036854775807_10178, %int1_10179 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8252, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_10180 = torch.constant.int 4 + %int1_10181 = torch.constant.int 1 + %int1_10182 = torch.constant.int 1 + %int1_10183 = torch.constant.int 1 + %8253 = torch.prim.ListConstruct %int4_10180, %int1_10181, %int1_10182, %int1_10183 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8254 = torch.aten.repeat %8252, %8253 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %8254, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> 
: !torch.vtensor<[4,?,1,128],f16> + %int0_10184 = torch.constant.int 0 + %8255 = torch.aten.unsqueeze %8248, %int0_10184 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8255, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_10185 = torch.constant.int 1 + %int0_10186 = torch.constant.int 0 + %int9223372036854775807_10187 = torch.constant.int 9223372036854775807 + %int1_10188 = torch.constant.int 1 + %8256 = torch.aten.slice.Tensor %8255, %int1_10185, %int0_10186, %int9223372036854775807_10187, %int1_10188 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8256, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_10189 = torch.constant.int 2 + %8257 = torch.aten.unsqueeze %8256, %int2_10189 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8257, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_10190 = torch.constant.int 3 + %int0_10191 = torch.constant.int 0 + %int9223372036854775807_10192 = torch.constant.int 9223372036854775807 + %int1_10193 = torch.constant.int 1 + %8258 = torch.aten.slice.Tensor %8257, %int3_10190, %int0_10191, %int9223372036854775807_10192, %int1_10193 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8258, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_10194 = torch.constant.int 4 + %int1_10195 = torch.constant.int 1 + %int1_10196 = torch.constant.int 1 + %int1_10197 = torch.constant.int 1 + %8259 = torch.prim.ListConstruct %int4_10194, %int1_10195, %int1_10196, %int1_10197 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8260 = torch.aten.repeat %8258, %8259 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %8260, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %8261 = torch.aten.mul.Tensor %8140, %8254 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8261, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_10198 = torch.constant.int 3 + %int0_10199 = torch.constant.int 0 + %int64_10200 = torch.constant.int 64 + %int1_10201 = torch.constant.int 1 + %8262 = torch.aten.slice.Tensor %8140, %int3_10198, %int0_10199, %int64_10200, %int1_10201 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %8262, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_10202 = torch.constant.int 3 + %int64_10203 = torch.constant.int 64 + %int9223372036854775807_10204 = torch.constant.int 9223372036854775807 + %int1_10205 = torch.constant.int 1 + %8263 = torch.aten.slice.Tensor %8140, %int3_10202, %int64_10203, %int9223372036854775807_10204, %int1_10205 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %8263, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %8264 = torch.aten.neg %8263 : !torch.vtensor<[4,?,8,64],f16> -> 
!torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %8264, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %8265 = torch.prim.ListConstruct %8264, %8262 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_10206 = torch.constant.int -1 + %8266 = torch.aten.cat %8265, %int-1_10206 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8266, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %8267 = torch.aten.mul.Tensor %8266, %8260 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8267, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_10207 = torch.constant.int 1 + %8268 = torch.aten.add.Tensor %8261, %8267, %int1_10207 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8268, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_10208 = torch.constant.int 32 + %8269 = torch.aten.mul.Scalar %arg2, %int32_10208 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8269, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int29 = torch.constant.int 29 + %int1_10209 = torch.constant.int 1 + %8270 = torch.aten.add.Scalar %8269, %int29, %int1_10209 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8270, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_10210 = torch.constant.int 2 + %8271 = torch.aten.mul.Scalar %8270, %int2_10210 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8271, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_10211 = torch.constant.int 0 + %int1_10212 = torch.constant.int 1 + %8272 = torch.aten.add.Scalar %8271, %int0_10211, %int1_10212 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8272, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %8273 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %8274 = torch.aten.view %8272, %8273 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %8274, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_10213 = torch.constant.int 4 + %int32_10214 = torch.constant.int 32 + %int8_10215 = torch.constant.int 8 + %int128_10216 = torch.constant.int 128 + %8275 = torch.prim.ListConstruct %int4_10213, %296, %int32_10214, %int8_10215, %int128_10216 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8276 = torch.aten.view %8268, %8275 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %8276, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_10217 = torch.constant.int 32 + %int8_10218 = torch.constant.int 8 + %int128_10219 = torch.constant.int 128 + %8277 = torch.prim.ListConstruct %504, %int32_10217, %int8_10218, %int128_10219 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8278 = torch.aten.view %8276, %8277 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + 
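// Paged KV-cache write for K: with %arg2 holding what appear to be per-sequence page ids,
// the flattened slot index computed above is (page*32 + 29)*2 + 0, i.e. the K plane of
// block 29 inside the [?,32,2,8,32,128] view of the 2097152-wide cache; the transpose and
// index_put below scatter the rotated K into that plane, and V follows with slot offset +1.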
torch.bind_symbolic_shape %8278, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_10220 = torch.constant.int 1 + %int2_10221 = torch.constant.int 2 + %8279 = torch.aten.transpose.int %8278, %int1_10220, %int2_10221 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8279, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_10222 = torch.constant.int 5 + %8280 = torch.prims.convert_element_type %8279, %int5_10222 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8280, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_10223 = torch.constant.int 32 + %int2_10224 = torch.constant.int 2 + %int8_10225 = torch.constant.int 8 + %int32_10226 = torch.constant.int 32 + %int128_10227 = torch.constant.int 128 + %8281 = torch.prim.ListConstruct %297, %int32_10223, %int2_10224, %int8_10225, %int32_10226, %int128_10227 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8282 = torch.aten.view %8044, %8281 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8282, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_10228 = torch.constant.int 8 + %int32_10229 = torch.constant.int 32 + %int128_10230 = torch.constant.int 128 + %8283 = torch.prim.ListConstruct %497, %int8_10228, %int32_10229, %int128_10230 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8284 = torch.aten.view %8282, %8283 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8284, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %8285 = torch.prim.ListConstruct %8274 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_10231 = torch.constant.bool false + %8286 = torch.aten.index_put %8284, %8285, %8280, %false_10231 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8286, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_10232 = torch.constant.int 32 + %int2_10233 = torch.constant.int 2 + %int8_10234 = torch.constant.int 8 + %int32_10235 = torch.constant.int 32 + %int128_10236 = torch.constant.int 128 + %8287 = torch.prim.ListConstruct %297, %int32_10232, %int2_10233, %int8_10234, %int32_10235, %int128_10236 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8288 = torch.aten.view %8286, %8287 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8288, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_10237 = torch.constant.int 2097152 + %8289 = torch.prim.ListConstruct %297, %int2097152_10237 : (!torch.int, !torch.int) -> !torch.list + %8290 = torch.aten.view %8288, %8289 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %8290, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_10238 = torch.constant.int 32 + %int2_10239 = torch.constant.int 2 + %int8_10240 = torch.constant.int 8 + 
%int32_10241 = torch.constant.int 32 + %int128_10242 = torch.constant.int 128 + %8291 = torch.prim.ListConstruct %297, %int32_10238, %int2_10239, %int8_10240, %int32_10241, %int128_10242 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8292 = torch.aten.view %8290, %8291 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8292, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_10243 = torch.constant.int 8 + %int32_10244 = torch.constant.int 32 + %int128_10245 = torch.constant.int 128 + %8293 = torch.prim.ListConstruct %497, %int8_10243, %int32_10244, %int128_10245 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8294 = torch.aten.view %8292, %8293 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8294, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_10246 = torch.constant.int 32 + %8295 = torch.aten.mul.Scalar %arg2, %int32_10246 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8295, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int29_10247 = torch.constant.int 29 + %int1_10248 = torch.constant.int 1 + %8296 = torch.aten.add.Scalar %8295, %int29_10247, %int1_10248 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8296, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_10249 = torch.constant.int 2 + %8297 = torch.aten.mul.Scalar %8296, %int2_10249 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8297, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_10250 = torch.constant.int 1 + %int1_10251 = torch.constant.int 1 + %8298 = torch.aten.add.Scalar %8297, %int1_10250, %int1_10251 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8298, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %8299 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %8300 = torch.aten.view %8298, %8299 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %8300, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_10252 = torch.constant.int 4 + %int32_10253 = torch.constant.int 32 + %int8_10254 = torch.constant.int 8 + %int128_10255 = torch.constant.int 128 + %8301 = torch.prim.ListConstruct %int4_10252, %296, %int32_10253, %int8_10254, %int128_10255 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8302 = torch.aten.view %8142, %8301 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %8302, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_10256 = torch.constant.int 32 + %int8_10257 = torch.constant.int 8 + %int128_10258 = torch.constant.int 128 + %8303 = torch.prim.ListConstruct %504, %int32_10256, %int8_10257, %int128_10258 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8304 = torch.aten.view %8302, %8303 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %8304, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : 
!torch.vtensor<[?,32,8,128],f16> + %int1_10259 = torch.constant.int 1 + %int2_10260 = torch.constant.int 2 + %8305 = torch.aten.transpose.int %8304, %int1_10259, %int2_10260 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8305, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_10261 = torch.constant.int 5 + %8306 = torch.prims.convert_element_type %8305, %int5_10261 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8306, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %8307 = torch.prim.ListConstruct %8300 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_10262 = torch.constant.bool false + %8308 = torch.aten.index_put %8294, %8307, %8306, %false_10262 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8308, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_10263 = torch.constant.int 32 + %int2_10264 = torch.constant.int 2 + %int8_10265 = torch.constant.int 8 + %int32_10266 = torch.constant.int 32 + %int128_10267 = torch.constant.int 128 + %8309 = torch.prim.ListConstruct %297, %int32_10263, %int2_10264, %int8_10265, %int32_10266, %int128_10267 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8310 = torch.aten.view %8308, %8309 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8310, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_10268 = torch.constant.int 2097152 + %8311 = torch.prim.ListConstruct %297, %int2097152_10268 : (!torch.int, !torch.int) -> !torch.list + %8312 = torch.aten.view %8310, %8311 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %8312, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_10269 = torch.constant.int -2 + %8313 = torch.aten.unsqueeze %8268, %int-2_10269 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %8313, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_10270 = torch.constant.int 4 + %int8_10271 = torch.constant.int 8 + %int4_10272 = torch.constant.int 4 + %int128_10273 = torch.constant.int 128 + %8314 = torch.prim.ListConstruct %int4_10270, %298, %int8_10271, %int4_10272, %int128_10273 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_10274 = torch.constant.bool false + %8315 = torch.aten.expand %8313, %8314, %false_10274 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8315, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_10275 = torch.constant.int 0 + %8316 = torch.aten.clone %8315, %int0_10275 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8316, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_10276 = torch.constant.int 4 + %int32_10277 = torch.constant.int 32 + %int128_10278 = torch.constant.int 128 + %8317 = 
torch.prim.ListConstruct %int4_10276, %298, %int32_10277, %int128_10278 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8318 = torch.aten._unsafe_view %8316, %8317 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8318, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_10279 = torch.constant.int -2 + %8319 = torch.aten.unsqueeze %8142, %int-2_10279 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %8319, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_10280 = torch.constant.int 4 + %int8_10281 = torch.constant.int 8 + %int4_10282 = torch.constant.int 4 + %int128_10283 = torch.constant.int 128 + %8320 = torch.prim.ListConstruct %int4_10280, %298, %int8_10281, %int4_10282, %int128_10283 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_10284 = torch.constant.bool false + %8321 = torch.aten.expand %8319, %8320, %false_10284 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8321, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_10285 = torch.constant.int 0 + %8322 = torch.aten.clone %8321, %int0_10285 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8322, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_10286 = torch.constant.int 4 + %int32_10287 = torch.constant.int 32 + %int128_10288 = torch.constant.int 128 + %8323 = torch.prim.ListConstruct %int4_10286, %298, %int32_10287, %int128_10288 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8324 = torch.aten._unsafe_view %8322, %8323 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8324, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_10289 = torch.constant.int 1 + %int2_10290 = torch.constant.int 2 + %8325 = torch.aten.transpose.int %8205, %int1_10289, %int2_10290 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %8325, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_10291 = torch.constant.int 1 + %int2_10292 = torch.constant.int 2 + %8326 = torch.aten.transpose.int %8318, %int1_10291, %int2_10292 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %8326, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_10293 = torch.constant.int 1 + %int2_10294 = torch.constant.int 2 + %8327 = torch.aten.transpose.int %8324, %int1_10293, %int2_10294 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %8327, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_10295 = torch.constant.float 0.000000e+00 + %false_10296 = torch.constant.bool false + %none_10297 = torch.constant.none + %8328:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%8325, %8326, %8327, %float0.000000e00_10295, %false_10296, %327, %none_10297) : 
(!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %8328#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_10298 = torch.constant.int 1 + %int2_10299 = torch.constant.int 2 + %8329 = torch.aten.transpose.int %8328#0, %int1_10298, %int2_10299 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8329, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_10300 = torch.constant.int 4 + %int4096_10301 = torch.constant.int 4096 + %8330 = torch.prim.ListConstruct %int4_10300, %298, %int4096_10301 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8331 = torch.aten.view %8329, %8330 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8331, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_10302 = torch.constant.int -2 + %int-1_10303 = torch.constant.int -1 + %8332 = torch.aten.transpose.int %267, %int-2_10302, %int-1_10303 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_10304 = torch.constant.int 5 + %8333 = torch.prims.convert_element_type %8332, %int5_10304 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_10305 = torch.constant.int 4096 + %8334 = torch.prim.ListConstruct %342, %int4096_10305 : (!torch.int, !torch.int) -> !torch.list + %8335 = torch.aten.view %8331, %8334 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8335, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8336 = torch.aten.mm %8335, %8333 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8336, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_10306 = torch.constant.int 4 + %int4096_10307 = torch.constant.int 4096 + %8337 = torch.prim.ListConstruct %int4_10306, %298, %int4096_10307 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8338 = torch.aten.view %8336, %8337 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8338, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_10308 = torch.constant.int 1 + %8339 = torch.aten.add.Tensor %8105, %8338, %int1_10308 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8339, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_10309 = torch.constant.int 6 + %8340 = torch.prims.convert_element_type %8339, %int6_10309 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8340, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_10310 = torch.constant.int 2 + %8341 = torch.aten.pow.Tensor_Scalar %8340, %int2_10310 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8341, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + 
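// ffn_norm RMSNorm: the mean-of-squares, add-eps (~1e-5), rsqrt, and rescale below
// normalize the attention residual %8339 in f32 before the f32 norm weight is applied and
// the result is cast back to f16 for the feed-forward matmuls.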
%int-1_10311 = torch.constant.int -1 + %8342 = torch.prim.ListConstruct %int-1_10311 : (!torch.int) -> !torch.list + %true_10312 = torch.constant.bool true + %none_10313 = torch.constant.none + %8343 = torch.aten.mean.dim %8341, %8342, %true_10312, %none_10313 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8343, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_10314 = torch.constant.float 9.9999997473787516E-6 + %int1_10315 = torch.constant.int 1 + %8344 = torch.aten.add.Scalar %8343, %float9.999990e-06_10314, %int1_10315 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8344, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8345 = torch.aten.rsqrt %8344 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8345, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8346 = torch.aten.mul.Tensor %8340, %8345 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8346, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_10316 = torch.constant.int 5 + %8347 = torch.prims.convert_element_type %8346, %int5_10316 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8347, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %8348 = torch.aten.mul.Tensor %268, %8347 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8348, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_10317 = torch.constant.int 5 + %8349 = torch.prims.convert_element_type %8348, %int5_10317 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8349, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_10318 = torch.constant.int -2 + %int-1_10319 = torch.constant.int -1 + %8350 = torch.aten.transpose.int %269, %int-2_10318, %int-1_10319 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_10320 = torch.constant.int 5 + %8351 = torch.prims.convert_element_type %8350, %int5_10320 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_10321 = torch.constant.int 4096 + %8352 = torch.prim.ListConstruct %342, %int4096_10321 : (!torch.int, !torch.int) -> !torch.list + %8353 = torch.aten.view %8349, %8352 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8353, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8354 = torch.aten.mm %8353, %8351 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %8354, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_10322 = torch.constant.int 4 + %int14336_10323 = torch.constant.int 14336 + %8355 = torch.prim.ListConstruct %int4_10322, %298, %int14336_10323 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8356 = torch.aten.view %8354, %8355 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + 
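// SwiGLU feed-forward: silu(x @ ffn_gate^T) * (x @ ffn_up^T), projected back through
// ffn_down^T (4096 -> 14336 -> 4096) and added to the residual stream.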
torch.bind_symbolic_shape %8356, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %8357 = torch.aten.silu %8356 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8357, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_10324 = torch.constant.int -2 + %int-1_10325 = torch.constant.int -1 + %8358 = torch.aten.transpose.int %270, %int-2_10324, %int-1_10325 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_10326 = torch.constant.int 5 + %8359 = torch.prims.convert_element_type %8358, %int5_10326 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_10327 = torch.constant.int 4096 + %8360 = torch.prim.ListConstruct %342, %int4096_10327 : (!torch.int, !torch.int) -> !torch.list + %8361 = torch.aten.view %8349, %8360 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8361, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8362 = torch.aten.mm %8361, %8359 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %8362, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_10328 = torch.constant.int 4 + %int14336_10329 = torch.constant.int 14336 + %8363 = torch.prim.ListConstruct %int4_10328, %298, %int14336_10329 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8364 = torch.aten.view %8362, %8363 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8364, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %8365 = torch.aten.mul.Tensor %8357, %8364 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8365, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_10330 = torch.constant.int -2 + %int-1_10331 = torch.constant.int -1 + %8366 = torch.aten.transpose.int %271, %int-2_10330, %int-1_10331 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_10332 = torch.constant.int 5 + %8367 = torch.prims.convert_element_type %8366, %int5_10332 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_10333 = torch.constant.int 14336 + %8368 = torch.prim.ListConstruct %342, %int14336_10333 : (!torch.int, !torch.int) -> !torch.list + %8369 = torch.aten.view %8365, %8368 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %8369, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %8370 = torch.aten.mm %8369, %8367 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8370, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_10334 = torch.constant.int 4 + %int4096_10335 = torch.constant.int 4096 + %8371 = torch.prim.ListConstruct %int4_10334, %298, %int4096_10335 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8372 = torch.aten.view %8370, %8371 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8372, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : 
!torch.vtensor<[4,?,4096],f16> + %int1_10336 = torch.constant.int 1 + %8373 = torch.aten.add.Tensor %8339, %8372, %int1_10336 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8373, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_10337 = torch.constant.int 6 + %8374 = torch.prims.convert_element_type %8373, %int6_10337 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8374, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_10338 = torch.constant.int 2 + %8375 = torch.aten.pow.Tensor_Scalar %8374, %int2_10338 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8375, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_10339 = torch.constant.int -1 + %8376 = torch.prim.ListConstruct %int-1_10339 : (!torch.int) -> !torch.list + %true_10340 = torch.constant.bool true + %none_10341 = torch.constant.none + %8377 = torch.aten.mean.dim %8375, %8376, %true_10340, %none_10341 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8377, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_10342 = torch.constant.float 9.9999997473787516E-6 + %int1_10343 = torch.constant.int 1 + %8378 = torch.aten.add.Scalar %8377, %float9.999990e-06_10342, %int1_10343 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8378, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8379 = torch.aten.rsqrt %8378 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8379, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8380 = torch.aten.mul.Tensor %8374, %8379 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8380, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_10344 = torch.constant.int 5 + %8381 = torch.prims.convert_element_type %8380, %int5_10344 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8381, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %8382 = torch.aten.mul.Tensor %272, %8381 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8382, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_10345 = torch.constant.int 5 + %8383 = torch.prims.convert_element_type %8382, %int5_10345 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8383, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_10346 = torch.constant.int -2 + %int-1_10347 = torch.constant.int -1 + %8384 = torch.aten.transpose.int %273, %int-2_10346, %int-1_10347 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_10348 = torch.constant.int 5 + %8385 = torch.prims.convert_element_type %8384, %int5_10348 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_10349 = 
torch.constant.int 4096 + %8386 = torch.prim.ListConstruct %342, %int4096_10349 : (!torch.int, !torch.int) -> !torch.list + %8387 = torch.aten.view %8383, %8386 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8387, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8388 = torch.aten.mm %8387, %8385 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8388, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_10350 = torch.constant.int 4 + %int4096_10351 = torch.constant.int 4096 + %8389 = torch.prim.ListConstruct %int4_10350, %298, %int4096_10351 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8390 = torch.aten.view %8388, %8389 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8390, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_10352 = torch.constant.int -2 + %int-1_10353 = torch.constant.int -1 + %8391 = torch.aten.transpose.int %274, %int-2_10352, %int-1_10353 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_10354 = torch.constant.int 5 + %8392 = torch.prims.convert_element_type %8391, %int5_10354 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_10355 = torch.constant.int 4096 + %8393 = torch.prim.ListConstruct %342, %int4096_10355 : (!torch.int, !torch.int) -> !torch.list + %8394 = torch.aten.view %8383, %8393 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8394, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8395 = torch.aten.mm %8394, %8392 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %8395, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_10356 = torch.constant.int 4 + %int1024_10357 = torch.constant.int 1024 + %8396 = torch.prim.ListConstruct %int4_10356, %298, %int1024_10357 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8397 = torch.aten.view %8395, %8396 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %8397, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_10358 = torch.constant.int -2 + %int-1_10359 = torch.constant.int -1 + %8398 = torch.aten.transpose.int %275, %int-2_10358, %int-1_10359 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_10360 = torch.constant.int 5 + %8399 = torch.prims.convert_element_type %8398, %int5_10360 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_10361 = torch.constant.int 4096 + %8400 = torch.prim.ListConstruct %342, %int4096_10361 : (!torch.int, !torch.int) -> !torch.list + %8401 = torch.aten.view %8383, %8400 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8401, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8402 = torch.aten.mm %8401, %8399 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %8402, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + 
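// Q/K/V projections for the next block: the 4096->4096 Q and the two 4096->1024 K/V
// matmuls above are reshaped below into 32 query heads and 8 KV heads of dim 128
// (grouped-query attention, group size 4) before the rotary position tables are rebuilt.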
%int4_10362 = torch.constant.int 4 + %int1024_10363 = torch.constant.int 1024 + %8403 = torch.prim.ListConstruct %int4_10362, %298, %int1024_10363 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8404 = torch.aten.view %8402, %8403 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %8404, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_10364 = torch.constant.int 4 + %int32_10365 = torch.constant.int 32 + %int128_10366 = torch.constant.int 128 + %8405 = torch.prim.ListConstruct %int4_10364, %298, %int32_10365, %int128_10366 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8406 = torch.aten.view %8390, %8405 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8406, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_10367 = torch.constant.int 4 + %int8_10368 = torch.constant.int 8 + %int128_10369 = torch.constant.int 128 + %8407 = torch.prim.ListConstruct %int4_10367, %298, %int8_10368, %int128_10369 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8408 = torch.aten.view %8397, %8407 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8408, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_10370 = torch.constant.int 4 + %int8_10371 = torch.constant.int 8 + %int128_10372 = torch.constant.int 128 + %8409 = torch.prim.ListConstruct %int4_10370, %298, %int8_10371, %int128_10372 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8410 = torch.aten.view %8404, %8409 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8410, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_10373 = torch.constant.int 131072 + %none_10374 = torch.constant.none + %none_10375 = torch.constant.none + %cpu_10376 = torch.constant.device "cpu" + %false_10377 = torch.constant.bool false + %8411 = torch.aten.arange %int131072_10373, %none_10374, %none_10375, %cpu_10376, %false_10377 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_10378 = torch.constant.int 0 + %int128_10379 = torch.constant.int 128 + %int2_10380 = torch.constant.int 2 + %int4_10381 = torch.constant.int 4 + %none_10382 = torch.constant.none + %cpu_10383 = torch.constant.device "cpu" + %false_10384 = torch.constant.bool false + %8412 = torch.aten.arange.start_step %int0_10378, %int128_10379, %int2_10380, %int4_10381, %none_10382, %cpu_10383, %false_10384 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_10385 = torch.constant.int 6 + %8413 = torch.prims.convert_element_type %8412, %int6_10385 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_10386 = torch.constant.int 128 + %8414 = torch.aten.div.Scalar %8413, %int128_10386 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_10387 = torch.constant.float 5.000000e+05 + %8415 = torch.aten.pow.Scalar %float5.000000e05_10387, %8414 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8416 = torch.aten.reciprocal %8415 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_10388 = torch.constant.float 1.000000e+00 + 
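// Below: the RoPE inverse-frequency table for head dim 128 with base 5.0e5, followed by what
// appears to be Llama-3.1-style frequency scaling: writing L = 2*pi/freq for the wavelength,
// freqs with L > 8192 are divided by the scale factor 8, freqs with L < 2048 are kept as is,
// and the band in between is blended as (1 - s) * freq/8 + s * freq with s = (8192/L - 1)/3.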
%8417 = torch.aten.mul.Scalar %8416, %float1.000000e00_10388 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %8418 = torch.aten.reciprocal %8417 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_10389 = torch.constant.float 6.2831853071795862 + %8419 = torch.aten.mul.Scalar %8418, %float6.283190e00_10389 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_10390 = torch.constant.float 8.192000e+03 + %8420 = torch.aten.gt.Scalar %8419, %float8.192000e03_10390 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_10391 = torch.constant.int 8 + %8421 = torch.aten.div.Scalar %8417, %int8_10391 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %8422 = torch.aten.where.self %8420, %8421, %8417 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8423 = torch.aten.reciprocal %8419 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_10392 = torch.constant.int 8192 + %8424 = torch.aten.mul.Scalar %8423, %int8192_10392 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_10393 = torch.constant.int 1 + %int1_10394 = torch.constant.int 1 + %8425 = torch.aten.sub.Scalar %8424, %int1_10393, %int1_10394 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_10395 = torch.constant.int 3 + %8426 = torch.aten.div.Scalar %8425, %int3_10395 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_10396 = torch.constant.int 1 + %int1_10397 = torch.constant.int 1 + %8427 = torch.aten.rsub.Scalar %8426, %int1_10396, %int1_10397 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %8428 = torch.aten.mul.Tensor %8427, %8422 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_10398 = torch.constant.int 8 + %8429 = torch.aten.div.Scalar %8428, %int8_10398 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %8430 = torch.aten.mul.Tensor %8426, %8422 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_10399 = torch.constant.int 1 + %8431 = torch.aten.add.Tensor %8429, %8430, %int1_10399 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_10400 = torch.constant.float 2.048000e+03 + %8432 = torch.aten.lt.Scalar %8419, %float2.048000e03_10400 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %8433 = torch.aten.bitwise_not %8432 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_10401 = torch.constant.float 8.192000e+03 + %8434 = torch.aten.gt.Scalar %8419, %float8.192000e03_10401 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %8435 = torch.aten.bitwise_not %8434 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %8436 = torch.aten.mul.Tensor %8433, %8435 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %8437 = torch.aten.where.self %8436, %8431, %8422 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8438 = torch.prim.ListConstruct %8437, %8437 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_10402 = torch.constant.int -1 + %8439 = torch.aten.cat %8438, %int-1_10402 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_10403 = torch.constant.int 6 + %8440 = 
torch.prims.convert_element_type %8439, %int6_10403 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_10404 = torch.constant.int 1 + %8441 = torch.aten.unsqueeze %8411, %int1_10404 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_10405 = torch.constant.int 6 + %8442 = torch.prims.convert_element_type %8441, %int6_10405 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_10406 = torch.constant.int 0 + %8443 = torch.aten.unsqueeze %8440, %int0_10406 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_10407 = torch.constant.int 6 + %8444 = torch.prims.convert_element_type %8443, %int6_10407 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %8445 = torch.aten.mul.Tensor %8442, %8444 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %8446 = torch.aten.cos %8445 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_10408 = torch.constant.int 5 + %8447 = torch.prims.convert_element_type %8446, %int5_10408 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %8448 = torch.aten.sin %8445 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_10409 = torch.constant.int 5 + %8449 = torch.prims.convert_element_type %8448, %int5_10409 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_10410 = torch.constant.int 0 + %int0_10411 = torch.constant.int 0 + %int1_10412 = torch.constant.int 1 + %8450 = torch.aten.slice.Tensor %8447, %int0_10410, %int0_10411, %298, %int1_10412 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8450, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_10413 = torch.constant.int 1 + %int0_10414 = torch.constant.int 0 + %int9223372036854775807_10415 = torch.constant.int 9223372036854775807 + %int1_10416 = torch.constant.int 1 + %8451 = torch.aten.slice.Tensor %8450, %int1_10413, %int0_10414, %int9223372036854775807_10415, %int1_10416 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8451, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_10417 = torch.constant.int 0 + %int0_10418 = torch.constant.int 0 + %int1_10419 = torch.constant.int 1 + %8452 = torch.aten.slice.Tensor %8449, %int0_10417, %int0_10418, %298, %int1_10419 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8452, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_10420 = torch.constant.int 1 + %int0_10421 = torch.constant.int 0 + %int9223372036854775807_10422 = torch.constant.int 9223372036854775807 + %int1_10423 = torch.constant.int 1 + %8453 = torch.aten.slice.Tensor %8452, %int1_10420, %int0_10421, %int9223372036854775807_10422, %int1_10423 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8453, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_10424 = torch.constant.int 0 + %8454 = torch.aten.unsqueeze %8451, %int0_10424 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + 
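// Above: positions 0..131071 times the scaled frequencies (duplicated across both
// 64-element halves) give f16 cos/sin tables of shape [131072, 128], sliced down to the
// actual sequence length. Below: the tables are broadcast to [4, seq, 1, 128] and the
// rotation is applied to Q as q*cos + rotate_half(q)*sin, where rotate_half(q) is
// cat(-q[..., 64:], q[..., :64]).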
torch.bind_symbolic_shape %8454, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_10425 = torch.constant.int 1 + %int0_10426 = torch.constant.int 0 + %int9223372036854775807_10427 = torch.constant.int 9223372036854775807 + %int1_10428 = torch.constant.int 1 + %8455 = torch.aten.slice.Tensor %8454, %int1_10425, %int0_10426, %int9223372036854775807_10427, %int1_10428 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8455, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_10429 = torch.constant.int 2 + %8456 = torch.aten.unsqueeze %8455, %int2_10429 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8456, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_10430 = torch.constant.int 3 + %int0_10431 = torch.constant.int 0 + %int9223372036854775807_10432 = torch.constant.int 9223372036854775807 + %int1_10433 = torch.constant.int 1 + %8457 = torch.aten.slice.Tensor %8456, %int3_10430, %int0_10431, %int9223372036854775807_10432, %int1_10433 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8457, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_10434 = torch.constant.int 4 + %int1_10435 = torch.constant.int 1 + %int1_10436 = torch.constant.int 1 + %int1_10437 = torch.constant.int 1 + %8458 = torch.prim.ListConstruct %int4_10434, %int1_10435, %int1_10436, %int1_10437 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8459 = torch.aten.repeat %8457, %8458 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %8459, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_10438 = torch.constant.int 0 + %8460 = torch.aten.unsqueeze %8453, %int0_10438 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8460, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_10439 = torch.constant.int 1 + %int0_10440 = torch.constant.int 0 + %int9223372036854775807_10441 = torch.constant.int 9223372036854775807 + %int1_10442 = torch.constant.int 1 + %8461 = torch.aten.slice.Tensor %8460, %int1_10439, %int0_10440, %int9223372036854775807_10441, %int1_10442 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8461, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_10443 = torch.constant.int 2 + %8462 = torch.aten.unsqueeze %8461, %int2_10443 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8462, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_10444 = torch.constant.int 3 + %int0_10445 = torch.constant.int 0 + %int9223372036854775807_10446 = torch.constant.int 9223372036854775807 + %int1_10447 = torch.constant.int 1 + %8463 = torch.aten.slice.Tensor %8462, %int3_10444, %int0_10445, %int9223372036854775807_10446, %int1_10447 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8463, [%294], affine_map<()[s0] -> 
(1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_10448 = torch.constant.int 4 + %int1_10449 = torch.constant.int 1 + %int1_10450 = torch.constant.int 1 + %int1_10451 = torch.constant.int 1 + %8464 = torch.prim.ListConstruct %int4_10448, %int1_10449, %int1_10450, %int1_10451 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8465 = torch.aten.repeat %8463, %8464 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %8465, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %8466 = torch.aten.mul.Tensor %8406, %8459 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8466, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_10452 = torch.constant.int 3 + %int0_10453 = torch.constant.int 0 + %int64_10454 = torch.constant.int 64 + %int1_10455 = torch.constant.int 1 + %8467 = torch.aten.slice.Tensor %8406, %int3_10452, %int0_10453, %int64_10454, %int1_10455 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %8467, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_10456 = torch.constant.int 3 + %int64_10457 = torch.constant.int 64 + %int9223372036854775807_10458 = torch.constant.int 9223372036854775807 + %int1_10459 = torch.constant.int 1 + %8468 = torch.aten.slice.Tensor %8406, %int3_10456, %int64_10457, %int9223372036854775807_10458, %int1_10459 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %8468, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %8469 = torch.aten.neg %8468 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %8469, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %8470 = torch.prim.ListConstruct %8469, %8467 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_10460 = torch.constant.int -1 + %8471 = torch.aten.cat %8470, %int-1_10460 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8471, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %8472 = torch.aten.mul.Tensor %8471, %8465 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8472, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_10461 = torch.constant.int 1 + %8473 = torch.aten.add.Tensor %8466, %8472, %int1_10461 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8473, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_10462 = torch.constant.int 131072 + %none_10463 = torch.constant.none + %none_10464 = torch.constant.none + %cpu_10465 = torch.constant.device "cpu" + %false_10466 = torch.constant.bool false + %8474 = torch.aten.arange %int131072_10462, %none_10463, %none_10464, %cpu_10465, %false_10466 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_10467 = torch.constant.int 0 + 
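// Below: the whole arange/frequency/cos/sin construction is emitted a second time for the
// K path (%8474 onward mirrors %8411 onward); this non-decomposed export repeats the
// computation instead of reusing the table already built for Q.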
%int128_10468 = torch.constant.int 128 + %int2_10469 = torch.constant.int 2 + %int4_10470 = torch.constant.int 4 + %none_10471 = torch.constant.none + %cpu_10472 = torch.constant.device "cpu" + %false_10473 = torch.constant.bool false + %8475 = torch.aten.arange.start_step %int0_10467, %int128_10468, %int2_10469, %int4_10470, %none_10471, %cpu_10472, %false_10473 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_10474 = torch.constant.int 6 + %8476 = torch.prims.convert_element_type %8475, %int6_10474 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_10475 = torch.constant.int 128 + %8477 = torch.aten.div.Scalar %8476, %int128_10475 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_10476 = torch.constant.float 5.000000e+05 + %8478 = torch.aten.pow.Scalar %float5.000000e05_10476, %8477 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8479 = torch.aten.reciprocal %8478 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_10477 = torch.constant.float 1.000000e+00 + %8480 = torch.aten.mul.Scalar %8479, %float1.000000e00_10477 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %8481 = torch.aten.reciprocal %8480 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_10478 = torch.constant.float 6.2831853071795862 + %8482 = torch.aten.mul.Scalar %8481, %float6.283190e00_10478 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_10479 = torch.constant.float 8.192000e+03 + %8483 = torch.aten.gt.Scalar %8482, %float8.192000e03_10479 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_10480 = torch.constant.int 8 + %8484 = torch.aten.div.Scalar %8480, %int8_10480 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %8485 = torch.aten.where.self %8483, %8484, %8480 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8486 = torch.aten.reciprocal %8482 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_10481 = torch.constant.int 8192 + %8487 = torch.aten.mul.Scalar %8486, %int8192_10481 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_10482 = torch.constant.int 1 + %int1_10483 = torch.constant.int 1 + %8488 = torch.aten.sub.Scalar %8487, %int1_10482, %int1_10483 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_10484 = torch.constant.int 3 + %8489 = torch.aten.div.Scalar %8488, %int3_10484 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_10485 = torch.constant.int 1 + %int1_10486 = torch.constant.int 1 + %8490 = torch.aten.rsub.Scalar %8489, %int1_10485, %int1_10486 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %8491 = torch.aten.mul.Tensor %8490, %8485 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_10487 = torch.constant.int 8 + %8492 = torch.aten.div.Scalar %8491, %int8_10487 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %8493 = torch.aten.mul.Tensor %8489, %8485 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_10488 = torch.constant.int 1 + %8494 = torch.aten.add.Tensor %8492, %8493, %int1_10488 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + 
%float2.048000e03_10489 = torch.constant.float 2.048000e+03 + %8495 = torch.aten.lt.Scalar %8482, %float2.048000e03_10489 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %8496 = torch.aten.bitwise_not %8495 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_10490 = torch.constant.float 8.192000e+03 + %8497 = torch.aten.gt.Scalar %8482, %float8.192000e03_10490 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %8498 = torch.aten.bitwise_not %8497 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %8499 = torch.aten.mul.Tensor %8496, %8498 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %8500 = torch.aten.where.self %8499, %8494, %8485 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8501 = torch.prim.ListConstruct %8500, %8500 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_10491 = torch.constant.int -1 + %8502 = torch.aten.cat %8501, %int-1_10491 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_10492 = torch.constant.int 6 + %8503 = torch.prims.convert_element_type %8502, %int6_10492 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_10493 = torch.constant.int 1 + %8504 = torch.aten.unsqueeze %8474, %int1_10493 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_10494 = torch.constant.int 6 + %8505 = torch.prims.convert_element_type %8504, %int6_10494 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_10495 = torch.constant.int 0 + %8506 = torch.aten.unsqueeze %8503, %int0_10495 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_10496 = torch.constant.int 6 + %8507 = torch.prims.convert_element_type %8506, %int6_10496 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %8508 = torch.aten.mul.Tensor %8505, %8507 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %8509 = torch.aten.cos %8508 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_10497 = torch.constant.int 5 + %8510 = torch.prims.convert_element_type %8509, %int5_10497 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %8511 = torch.aten.sin %8508 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_10498 = torch.constant.int 5 + %8512 = torch.prims.convert_element_type %8511, %int5_10498 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_10499 = torch.constant.int 0 + %int0_10500 = torch.constant.int 0 + %int1_10501 = torch.constant.int 1 + %8513 = torch.aten.slice.Tensor %8510, %int0_10499, %int0_10500, %298, %int1_10501 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8513, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_10502 = torch.constant.int 1 + %int0_10503 = torch.constant.int 0 + %int9223372036854775807_10504 = torch.constant.int 9223372036854775807 + %int1_10505 = torch.constant.int 1 + %8514 = torch.aten.slice.Tensor %8513, %int1_10502, %int0_10503, %int9223372036854775807_10504, %int1_10505 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8514, [%294], affine_map<()[s0] -> (s0 * 32, 
128)> : !torch.vtensor<[?,128],f16> + %int0_10506 = torch.constant.int 0 + %int0_10507 = torch.constant.int 0 + %int1_10508 = torch.constant.int 1 + %8515 = torch.aten.slice.Tensor %8512, %int0_10506, %int0_10507, %298, %int1_10508 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8515, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_10509 = torch.constant.int 1 + %int0_10510 = torch.constant.int 0 + %int9223372036854775807_10511 = torch.constant.int 9223372036854775807 + %int1_10512 = torch.constant.int 1 + %8516 = torch.aten.slice.Tensor %8515, %int1_10509, %int0_10510, %int9223372036854775807_10511, %int1_10512 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8516, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_10513 = torch.constant.int 0 + %8517 = torch.aten.unsqueeze %8514, %int0_10513 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8517, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_10514 = torch.constant.int 1 + %int0_10515 = torch.constant.int 0 + %int9223372036854775807_10516 = torch.constant.int 9223372036854775807 + %int1_10517 = torch.constant.int 1 + %8518 = torch.aten.slice.Tensor %8517, %int1_10514, %int0_10515, %int9223372036854775807_10516, %int1_10517 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8518, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_10518 = torch.constant.int 2 + %8519 = torch.aten.unsqueeze %8518, %int2_10518 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8519, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_10519 = torch.constant.int 3 + %int0_10520 = torch.constant.int 0 + %int9223372036854775807_10521 = torch.constant.int 9223372036854775807 + %int1_10522 = torch.constant.int 1 + %8520 = torch.aten.slice.Tensor %8519, %int3_10519, %int0_10520, %int9223372036854775807_10521, %int1_10522 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8520, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_10523 = torch.constant.int 4 + %int1_10524 = torch.constant.int 1 + %int1_10525 = torch.constant.int 1 + %int1_10526 = torch.constant.int 1 + %8521 = torch.prim.ListConstruct %int4_10523, %int1_10524, %int1_10525, %int1_10526 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8522 = torch.aten.repeat %8520, %8521 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %8522, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_10527 = torch.constant.int 0 + %8523 = torch.aten.unsqueeze %8516, %int0_10527 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8523, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_10528 = torch.constant.int 1 + %int0_10529 = torch.constant.int 0 + %int9223372036854775807_10530 = torch.constant.int 9223372036854775807 + 
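// Below: the duplicated tables are broadcast and the same rotation is applied to the 8 K
// heads (%8529..%8536). The rotated K is then written into the paged f16 KV cache (threaded
// in as %8312): each 2097152-element page is viewed as [32 layers, 2 (K/V), 8 heads,
// 32 positions, 128], the page ids in %arg2 become flat slot indices (page*32 + 30)*2 + 0,
// the constant 30 marking this hunk as transformer block 30, and torch.aten.index_put
// scatters the [pages, 8, 32, 128] K tile in place. The same sequence with slot offset +1
// follows for V.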
%int1_10531 = torch.constant.int 1 + %8524 = torch.aten.slice.Tensor %8523, %int1_10528, %int0_10529, %int9223372036854775807_10530, %int1_10531 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8524, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_10532 = torch.constant.int 2 + %8525 = torch.aten.unsqueeze %8524, %int2_10532 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8525, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_10533 = torch.constant.int 3 + %int0_10534 = torch.constant.int 0 + %int9223372036854775807_10535 = torch.constant.int 9223372036854775807 + %int1_10536 = torch.constant.int 1 + %8526 = torch.aten.slice.Tensor %8525, %int3_10533, %int0_10534, %int9223372036854775807_10535, %int1_10536 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8526, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_10537 = torch.constant.int 4 + %int1_10538 = torch.constant.int 1 + %int1_10539 = torch.constant.int 1 + %int1_10540 = torch.constant.int 1 + %8527 = torch.prim.ListConstruct %int4_10537, %int1_10538, %int1_10539, %int1_10540 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8528 = torch.aten.repeat %8526, %8527 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %8528, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %8529 = torch.aten.mul.Tensor %8408, %8522 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8529, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_10541 = torch.constant.int 3 + %int0_10542 = torch.constant.int 0 + %int64_10543 = torch.constant.int 64 + %int1_10544 = torch.constant.int 1 + %8530 = torch.aten.slice.Tensor %8408, %int3_10541, %int0_10542, %int64_10543, %int1_10544 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %8530, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_10545 = torch.constant.int 3 + %int64_10546 = torch.constant.int 64 + %int9223372036854775807_10547 = torch.constant.int 9223372036854775807 + %int1_10548 = torch.constant.int 1 + %8531 = torch.aten.slice.Tensor %8408, %int3_10545, %int64_10546, %int9223372036854775807_10547, %int1_10548 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %8531, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %8532 = torch.aten.neg %8531 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %8532, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %8533 = torch.prim.ListConstruct %8532, %8530 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_10549 = torch.constant.int -1 + %8534 = torch.aten.cat %8533, %int-1_10549 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8534, [%294], 
affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %8535 = torch.aten.mul.Tensor %8534, %8528 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8535, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_10550 = torch.constant.int 1 + %8536 = torch.aten.add.Tensor %8529, %8535, %int1_10550 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8536, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_10551 = torch.constant.int 32 + %8537 = torch.aten.mul.Scalar %arg2, %int32_10551 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8537, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int30 = torch.constant.int 30 + %int1_10552 = torch.constant.int 1 + %8538 = torch.aten.add.Scalar %8537, %int30, %int1_10552 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8538, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_10553 = torch.constant.int 2 + %8539 = torch.aten.mul.Scalar %8538, %int2_10553 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8539, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_10554 = torch.constant.int 0 + %int1_10555 = torch.constant.int 1 + %8540 = torch.aten.add.Scalar %8539, %int0_10554, %int1_10555 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8540, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %8541 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %8542 = torch.aten.view %8540, %8541 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %8542, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_10556 = torch.constant.int 4 + %int32_10557 = torch.constant.int 32 + %int8_10558 = torch.constant.int 8 + %int128_10559 = torch.constant.int 128 + %8543 = torch.prim.ListConstruct %int4_10556, %296, %int32_10557, %int8_10558, %int128_10559 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8544 = torch.aten.view %8536, %8543 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %8544, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_10560 = torch.constant.int 32 + %int8_10561 = torch.constant.int 8 + %int128_10562 = torch.constant.int 128 + %8545 = torch.prim.ListConstruct %504, %int32_10560, %int8_10561, %int128_10562 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8546 = torch.aten.view %8544, %8545 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %8546, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_10563 = torch.constant.int 1 + %int2_10564 = torch.constant.int 2 + %8547 = torch.aten.transpose.int %8546, %int1_10563, %int2_10564 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8547, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : 
!torch.vtensor<[?,8,32,128],f16> + %int5_10565 = torch.constant.int 5 + %8548 = torch.prims.convert_element_type %8547, %int5_10565 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8548, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_10566 = torch.constant.int 32 + %int2_10567 = torch.constant.int 2 + %int8_10568 = torch.constant.int 8 + %int32_10569 = torch.constant.int 32 + %int128_10570 = torch.constant.int 128 + %8549 = torch.prim.ListConstruct %297, %int32_10566, %int2_10567, %int8_10568, %int32_10569, %int128_10570 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8550 = torch.aten.view %8312, %8549 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8550, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_10571 = torch.constant.int 8 + %int32_10572 = torch.constant.int 32 + %int128_10573 = torch.constant.int 128 + %8551 = torch.prim.ListConstruct %497, %int8_10571, %int32_10572, %int128_10573 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8552 = torch.aten.view %8550, %8551 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8552, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %8553 = torch.prim.ListConstruct %8542 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_10574 = torch.constant.bool false + %8554 = torch.aten.index_put %8552, %8553, %8548, %false_10574 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8554, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_10575 = torch.constant.int 32 + %int2_10576 = torch.constant.int 2 + %int8_10577 = torch.constant.int 8 + %int32_10578 = torch.constant.int 32 + %int128_10579 = torch.constant.int 128 + %8555 = torch.prim.ListConstruct %297, %int32_10575, %int2_10576, %int8_10577, %int32_10578, %int128_10579 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8556 = torch.aten.view %8554, %8555 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8556, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_10580 = torch.constant.int 2097152 + %8557 = torch.prim.ListConstruct %297, %int2097152_10580 : (!torch.int, !torch.int) -> !torch.list + %8558 = torch.aten.view %8556, %8557 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %8558, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_10581 = torch.constant.int 32 + %int2_10582 = torch.constant.int 2 + %int8_10583 = torch.constant.int 8 + %int32_10584 = torch.constant.int 32 + %int128_10585 = torch.constant.int 128 + %8559 = torch.prim.ListConstruct %297, %int32_10581, %int2_10582, %int8_10583, %int32_10584, %int128_10585 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8560 = torch.aten.view %8558, %8559 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8560, [%295], 
affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_10586 = torch.constant.int 8 + %int32_10587 = torch.constant.int 32 + %int128_10588 = torch.constant.int 128 + %8561 = torch.prim.ListConstruct %497, %int8_10586, %int32_10587, %int128_10588 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8562 = torch.aten.view %8560, %8561 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8562, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_10589 = torch.constant.int 32 + %8563 = torch.aten.mul.Scalar %arg2, %int32_10589 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8563, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int30_10590 = torch.constant.int 30 + %int1_10591 = torch.constant.int 1 + %8564 = torch.aten.add.Scalar %8563, %int30_10590, %int1_10591 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8564, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_10592 = torch.constant.int 2 + %8565 = torch.aten.mul.Scalar %8564, %int2_10592 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8565, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_10593 = torch.constant.int 1 + %int1_10594 = torch.constant.int 1 + %8566 = torch.aten.add.Scalar %8565, %int1_10593, %int1_10594 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8566, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %8567 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %8568 = torch.aten.view %8566, %8567 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %8568, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_10595 = torch.constant.int 4 + %int32_10596 = torch.constant.int 32 + %int8_10597 = torch.constant.int 8 + %int128_10598 = torch.constant.int 128 + %8569 = torch.prim.ListConstruct %int4_10595, %296, %int32_10596, %int8_10597, %int128_10598 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8570 = torch.aten.view %8410, %8569 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %8570, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_10599 = torch.constant.int 32 + %int8_10600 = torch.constant.int 8 + %int128_10601 = torch.constant.int 128 + %8571 = torch.prim.ListConstruct %504, %int32_10599, %int8_10600, %int128_10601 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8572 = torch.aten.view %8570, %8571 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %8572, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_10602 = torch.constant.int 1 + %int2_10603 = torch.constant.int 2 + %8573 = torch.aten.transpose.int %8572, %int1_10602, %int2_10603 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8573, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_10604 = torch.constant.int 5 + %8574 = 
torch.prims.convert_element_type %8573, %int5_10604 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8574, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %8575 = torch.prim.ListConstruct %8568 : (!torch.vtensor<[?],si64>) -> !torch.list> + %false_10605 = torch.constant.bool false + %8576 = torch.aten.index_put %8562, %8575, %8574, %false_10605 : !torch.vtensor<[?,8,32,128],f16>, !torch.list>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8576, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_10606 = torch.constant.int 32 + %int2_10607 = torch.constant.int 2 + %int8_10608 = torch.constant.int 8 + %int32_10609 = torch.constant.int 32 + %int128_10610 = torch.constant.int 128 + %8577 = torch.prim.ListConstruct %297, %int32_10606, %int2_10607, %int8_10608, %int32_10609, %int128_10610 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8578 = torch.aten.view %8576, %8577 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8578, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_10611 = torch.constant.int 2097152 + %8579 = torch.prim.ListConstruct %297, %int2097152_10611 : (!torch.int, !torch.int) -> !torch.list + %8580 = torch.aten.view %8578, %8579 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %8580, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_10612 = torch.constant.int -2 + %8581 = torch.aten.unsqueeze %8536, %int-2_10612 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %8581, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_10613 = torch.constant.int 4 + %int8_10614 = torch.constant.int 8 + %int4_10615 = torch.constant.int 4 + %int128_10616 = torch.constant.int 128 + %8582 = torch.prim.ListConstruct %int4_10613, %298, %int8_10614, %int4_10615, %int128_10616 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_10617 = torch.constant.bool false + %8583 = torch.aten.expand %8581, %8582, %false_10617 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8583, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_10618 = torch.constant.int 0 + %8584 = torch.aten.clone %8583, %int0_10618 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8584, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_10619 = torch.constant.int 4 + %int32_10620 = torch.constant.int 32 + %int128_10621 = torch.constant.int 128 + %8585 = torch.prim.ListConstruct %int4_10619, %298, %int32_10620, %int128_10621 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8586 = torch.aten._unsafe_view %8584, %8585 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8586, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_10622 = torch.constant.int -2 + %8587 = 
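// Below: V gets the same grouped-query expansion K received just above (%8581..%8586):
// [4, seq, 8, 128] -> unsqueeze -> expand -> clone -> [4, seq, 32, 128], replicating each
// of the 8 KV heads four times to line up with the 32 query heads.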
torch.aten.unsqueeze %8410, %int-2_10622 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %8587, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_10623 = torch.constant.int 4 + %int8_10624 = torch.constant.int 8 + %int4_10625 = torch.constant.int 4 + %int128_10626 = torch.constant.int 128 + %8588 = torch.prim.ListConstruct %int4_10623, %298, %int8_10624, %int4_10625, %int128_10626 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_10627 = torch.constant.bool false + %8589 = torch.aten.expand %8587, %8588, %false_10627 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8589, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_10628 = torch.constant.int 0 + %8590 = torch.aten.clone %8589, %int0_10628 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8590, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_10629 = torch.constant.int 4 + %int32_10630 = torch.constant.int 32 + %int128_10631 = torch.constant.int 128 + %8591 = torch.prim.ListConstruct %int4_10629, %298, %int32_10630, %int128_10631 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8592 = torch.aten._unsafe_view %8590, %8591 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8592, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_10632 = torch.constant.int 1 + %int2_10633 = torch.constant.int 2 + %8593 = torch.aten.transpose.int %8473, %int1_10632, %int2_10633 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %8593, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_10634 = torch.constant.int 1 + %int2_10635 = torch.constant.int 2 + %8594 = torch.aten.transpose.int %8586, %int1_10634, %int2_10635 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %8594, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_10636 = torch.constant.int 1 + %int2_10637 = torch.constant.int 2 + %8595 = torch.aten.transpose.int %8592, %int1_10636, %int2_10637 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %8595, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_10638 = torch.constant.float 0.000000e+00 + %false_10639 = torch.constant.bool false + %none_10640 = torch.constant.none + %8596:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%8593, %8594, %8595, %float0.000000e00_10638, %false_10639, %327, %none_10640) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %8596#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_10641 = torch.constant.int 1 + %int2_10642 = torch.constant.int 2 + %8597 = 
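// Above: Q and the expanded K/V are transposed to [4, 32, seq, 128] and handed to
// torch.aten._scaled_dot_product_flash_attention_for_cpu with dropout 0.0 and the f16
// attention mask %327 ([4, 1, seq, seq], built earlier in the function). Below: the
// attention output is transposed back, flattened to [4, seq, 4096], and run through the
// block-30 output projection (%276) into the first residual add.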
torch.aten.transpose.int %8596#0, %int1_10641, %int2_10642 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8597, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_10643 = torch.constant.int 4 + %int4096_10644 = torch.constant.int 4096 + %8598 = torch.prim.ListConstruct %int4_10643, %298, %int4096_10644 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8599 = torch.aten.view %8597, %8598 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8599, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_10645 = torch.constant.int -2 + %int-1_10646 = torch.constant.int -1 + %8600 = torch.aten.transpose.int %276, %int-2_10645, %int-1_10646 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_10647 = torch.constant.int 5 + %8601 = torch.prims.convert_element_type %8600, %int5_10647 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_10648 = torch.constant.int 4096 + %8602 = torch.prim.ListConstruct %342, %int4096_10648 : (!torch.int, !torch.int) -> !torch.list + %8603 = torch.aten.view %8599, %8602 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8603, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8604 = torch.aten.mm %8603, %8601 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8604, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_10649 = torch.constant.int 4 + %int4096_10650 = torch.constant.int 4096 + %8605 = torch.prim.ListConstruct %int4_10649, %298, %int4096_10650 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8606 = torch.aten.view %8604, %8605 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8606, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_10651 = torch.constant.int 1 + %8607 = torch.aten.add.Tensor %8373, %8606, %int1_10651 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8607, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_10652 = torch.constant.int 6 + %8608 = torch.prims.convert_element_type %8607, %int6_10652 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8608, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_10653 = torch.constant.int 2 + %8609 = torch.aten.pow.Tensor_Scalar %8608, %int2_10653 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8609, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_10654 = torch.constant.int -1 + %8610 = torch.prim.ListConstruct %int-1_10654 : (!torch.int) -> !torch.list + %true_10655 = torch.constant.bool true + %none_10656 = torch.constant.none + %8611 = torch.aten.mean.dim %8609, %8610, %true_10655, %none_10656 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8611, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : 
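// Below: the post-attention RMSNorm (scale %277) feeds block 30's SwiGLU MLP: the gate
// (%278) and up (%279) projections map 4096 -> 14336, the gate branch passes through
// torch.aten.silu and is multiplied elementwise with the up branch, and the down projection
// (%280) maps 14336 -> 4096 ahead of the second residual add.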
!torch.vtensor<[4,?,1],f32> + %float9.999990e-06_10657 = torch.constant.float 9.9999997473787516E-6 + %int1_10658 = torch.constant.int 1 + %8612 = torch.aten.add.Scalar %8611, %float9.999990e-06_10657, %int1_10658 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8612, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8613 = torch.aten.rsqrt %8612 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8613, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8614 = torch.aten.mul.Tensor %8608, %8613 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8614, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_10659 = torch.constant.int 5 + %8615 = torch.prims.convert_element_type %8614, %int5_10659 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8615, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %8616 = torch.aten.mul.Tensor %277, %8615 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8616, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_10660 = torch.constant.int 5 + %8617 = torch.prims.convert_element_type %8616, %int5_10660 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8617, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_10661 = torch.constant.int -2 + %int-1_10662 = torch.constant.int -1 + %8618 = torch.aten.transpose.int %278, %int-2_10661, %int-1_10662 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_10663 = torch.constant.int 5 + %8619 = torch.prims.convert_element_type %8618, %int5_10663 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_10664 = torch.constant.int 4096 + %8620 = torch.prim.ListConstruct %342, %int4096_10664 : (!torch.int, !torch.int) -> !torch.list + %8621 = torch.aten.view %8617, %8620 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8621, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8622 = torch.aten.mm %8621, %8619 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %8622, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_10665 = torch.constant.int 4 + %int14336_10666 = torch.constant.int 14336 + %8623 = torch.prim.ListConstruct %int4_10665, %298, %int14336_10666 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8624 = torch.aten.view %8622, %8623 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8624, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %8625 = torch.aten.silu %8624 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8625, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_10667 = torch.constant.int -2 + %int-1_10668 = torch.constant.int -1 + %8626 = torch.aten.transpose.int %279, 
%int-2_10667, %int-1_10668 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_10669 = torch.constant.int 5 + %8627 = torch.prims.convert_element_type %8626, %int5_10669 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_10670 = torch.constant.int 4096 + %8628 = torch.prim.ListConstruct %342, %int4096_10670 : (!torch.int, !torch.int) -> !torch.list + %8629 = torch.aten.view %8617, %8628 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8629, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8630 = torch.aten.mm %8629, %8627 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %8630, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_10671 = torch.constant.int 4 + %int14336_10672 = torch.constant.int 14336 + %8631 = torch.prim.ListConstruct %int4_10671, %298, %int14336_10672 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8632 = torch.aten.view %8630, %8631 : !torch.vtensor<[?,14336],f16>, !torch.list -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8632, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %8633 = torch.aten.mul.Tensor %8625, %8632 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8633, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_10673 = torch.constant.int -2 + %int-1_10674 = torch.constant.int -1 + %8634 = torch.aten.transpose.int %280, %int-2_10673, %int-1_10674 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_10675 = torch.constant.int 5 + %8635 = torch.prims.convert_element_type %8634, %int5_10675 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_10676 = torch.constant.int 14336 + %8636 = torch.prim.ListConstruct %342, %int14336_10676 : (!torch.int, !torch.int) -> !torch.list + %8637 = torch.aten.view %8633, %8636 : !torch.vtensor<[4,?,14336],f16>, !torch.list -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %8637, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %8638 = torch.aten.mm %8637, %8635 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8638, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_10677 = torch.constant.int 4 + %int4096_10678 = torch.constant.int 4096 + %8639 = torch.prim.ListConstruct %int4_10677, %298, %int4096_10678 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8640 = torch.aten.view %8638, %8639 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8640, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_10679 = torch.constant.int 1 + %8641 = torch.aten.add.Tensor %8607, %8640, %int1_10679 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8641, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_10680 = torch.constant.int 6 + %8642 = torch.prims.convert_element_type %8641, %int6_10680 : 
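// Above: the down-projection residual closes transformer block 30. Below: block 31 begins
// and repeats the identical pattern, pre-attention RMSNorm (scale %281) followed by Q/K/V
// projections against %282, %283 and %284.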
!torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8642, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_10681 = torch.constant.int 2 + %8643 = torch.aten.pow.Tensor_Scalar %8642, %int2_10681 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8643, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_10682 = torch.constant.int -1 + %8644 = torch.prim.ListConstruct %int-1_10682 : (!torch.int) -> !torch.list + %true_10683 = torch.constant.bool true + %none_10684 = torch.constant.none + %8645 = torch.aten.mean.dim %8643, %8644, %true_10683, %none_10684 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8645, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_10685 = torch.constant.float 9.9999997473787516E-6 + %int1_10686 = torch.constant.int 1 + %8646 = torch.aten.add.Scalar %8645, %float9.999990e-06_10685, %int1_10686 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8646, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8647 = torch.aten.rsqrt %8646 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8647, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8648 = torch.aten.mul.Tensor %8642, %8647 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8648, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_10687 = torch.constant.int 5 + %8649 = torch.prims.convert_element_type %8648, %int5_10687 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8649, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %8650 = torch.aten.mul.Tensor %281, %8649 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8650, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_10688 = torch.constant.int 5 + %8651 = torch.prims.convert_element_type %8650, %int5_10688 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8651, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_10689 = torch.constant.int -2 + %int-1_10690 = torch.constant.int -1 + %8652 = torch.aten.transpose.int %282, %int-2_10689, %int-1_10690 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_10691 = torch.constant.int 5 + %8653 = torch.prims.convert_element_type %8652, %int5_10691 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_10692 = torch.constant.int 4096 + %8654 = torch.prim.ListConstruct %342, %int4096_10692 : (!torch.int, !torch.int) -> !torch.list + %8655 = torch.aten.view %8651, %8654 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8655, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8656 = torch.aten.mm %8655, %8653 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> 
!torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8656, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_10693 = torch.constant.int 4 + %int4096_10694 = torch.constant.int 4096 + %8657 = torch.prim.ListConstruct %int4_10693, %298, %int4096_10694 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8658 = torch.aten.view %8656, %8657 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8658, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_10695 = torch.constant.int -2 + %int-1_10696 = torch.constant.int -1 + %8659 = torch.aten.transpose.int %283, %int-2_10695, %int-1_10696 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_10697 = torch.constant.int 5 + %8660 = torch.prims.convert_element_type %8659, %int5_10697 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_10698 = torch.constant.int 4096 + %8661 = torch.prim.ListConstruct %342, %int4096_10698 : (!torch.int, !torch.int) -> !torch.list + %8662 = torch.aten.view %8651, %8661 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8662, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8663 = torch.aten.mm %8662, %8660 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %8663, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_10699 = torch.constant.int 4 + %int1024_10700 = torch.constant.int 1024 + %8664 = torch.prim.ListConstruct %int4_10699, %298, %int1024_10700 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8665 = torch.aten.view %8663, %8664 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %8665, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int-2_10701 = torch.constant.int -2 + %int-1_10702 = torch.constant.int -1 + %8666 = torch.aten.transpose.int %284, %int-2_10701, %int-1_10702 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_10703 = torch.constant.int 5 + %8667 = torch.prims.convert_element_type %8666, %int5_10703 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_10704 = torch.constant.int 4096 + %8668 = torch.prim.ListConstruct %342, %int4096_10704 : (!torch.int, !torch.int) -> !torch.list + %8669 = torch.aten.view %8651, %8668 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8669, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8670 = torch.aten.mm %8669, %8667 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[?,1024],f16> + torch.bind_symbolic_shape %8670, [%294], affine_map<()[s0] -> (s0 * 128, 1024)> : !torch.vtensor<[?,1024],f16> + %int4_10705 = torch.constant.int 4 + %int1024_10706 = torch.constant.int 1024 + %8671 = torch.prim.ListConstruct %int4_10705, %298, %int1024_10706 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8672 = torch.aten.view %8670, %8671 : !torch.vtensor<[?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,1024],f16> + torch.bind_symbolic_shape %8672, [%294], affine_map<()[s0] -> (4, s0 * 32, 1024)> : !torch.vtensor<[4,?,1024],f16> + %int4_10707 = 
torch.constant.int 4 + %int32_10708 = torch.constant.int 32 + %int128_10709 = torch.constant.int 128 + %8673 = torch.prim.ListConstruct %int4_10707, %298, %int32_10708, %int128_10709 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8674 = torch.aten.view %8658, %8673 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8674, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_10710 = torch.constant.int 4 + %int8_10711 = torch.constant.int 8 + %int128_10712 = torch.constant.int 128 + %8675 = torch.prim.ListConstruct %int4_10710, %298, %int8_10711, %int128_10712 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8676 = torch.aten.view %8665, %8675 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8676, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int4_10713 = torch.constant.int 4 + %int8_10714 = torch.constant.int 8 + %int128_10715 = torch.constant.int 128 + %8677 = torch.prim.ListConstruct %int4_10713, %298, %int8_10714, %int128_10715 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8678 = torch.aten.view %8672, %8677 : !torch.vtensor<[4,?,1024],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8678, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int131072_10716 = torch.constant.int 131072 + %none_10717 = torch.constant.none + %none_10718 = torch.constant.none + %cpu_10719 = torch.constant.device "cpu" + %false_10720 = torch.constant.bool false + %8679 = torch.aten.arange %int131072_10716, %none_10717, %none_10718, %cpu_10719, %false_10720 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_10721 = torch.constant.int 0 + %int128_10722 = torch.constant.int 128 + %int2_10723 = torch.constant.int 2 + %int4_10724 = torch.constant.int 4 + %none_10725 = torch.constant.none + %cpu_10726 = torch.constant.device "cpu" + %false_10727 = torch.constant.bool false + %8680 = torch.aten.arange.start_step %int0_10721, %int128_10722, %int2_10723, %int4_10724, %none_10725, %cpu_10726, %false_10727 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6_10728 = torch.constant.int 6 + %8681 = torch.prims.convert_element_type %8680, %int6_10728 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_10729 = torch.constant.int 128 + %8682 = torch.aten.div.Scalar %8681, %int128_10729 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_10730 = torch.constant.float 5.000000e+05 + %8683 = torch.aten.pow.Scalar %float5.000000e05_10730, %8682 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8684 = torch.aten.reciprocal %8683 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_10731 = torch.constant.float 1.000000e+00 + %8685 = torch.aten.mul.Scalar %8684, %float1.000000e00_10731 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %8686 = torch.aten.reciprocal %8685 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_10732 = torch.constant.float 6.2831853071795862 + %8687 = torch.aten.mul.Scalar %8686, %float6.283190e00_10732 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_10733 = 
torch.constant.float 8.192000e+03 + %8688 = torch.aten.gt.Scalar %8687, %float8.192000e03_10733 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_10734 = torch.constant.int 8 + %8689 = torch.aten.div.Scalar %8685, %int8_10734 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %8690 = torch.aten.where.self %8688, %8689, %8685 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8691 = torch.aten.reciprocal %8687 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_10735 = torch.constant.int 8192 + %8692 = torch.aten.mul.Scalar %8691, %int8192_10735 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_10736 = torch.constant.int 1 + %int1_10737 = torch.constant.int 1 + %8693 = torch.aten.sub.Scalar %8692, %int1_10736, %int1_10737 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_10738 = torch.constant.int 3 + %8694 = torch.aten.div.Scalar %8693, %int3_10738 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_10739 = torch.constant.int 1 + %int1_10740 = torch.constant.int 1 + %8695 = torch.aten.rsub.Scalar %8694, %int1_10739, %int1_10740 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %8696 = torch.aten.mul.Tensor %8695, %8690 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_10741 = torch.constant.int 8 + %8697 = torch.aten.div.Scalar %8696, %int8_10741 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %8698 = torch.aten.mul.Tensor %8694, %8690 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_10742 = torch.constant.int 1 + %8699 = torch.aten.add.Tensor %8697, %8698, %int1_10742 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_10743 = torch.constant.float 2.048000e+03 + %8700 = torch.aten.lt.Scalar %8687, %float2.048000e03_10743 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %8701 = torch.aten.bitwise_not %8700 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_10744 = torch.constant.float 8.192000e+03 + %8702 = torch.aten.gt.Scalar %8687, %float8.192000e03_10744 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %8703 = torch.aten.bitwise_not %8702 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %8704 = torch.aten.mul.Tensor %8701, %8703 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %8705 = torch.aten.where.self %8704, %8699, %8690 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8706 = torch.prim.ListConstruct %8705, %8705 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_10745 = torch.constant.int -1 + %8707 = torch.aten.cat %8706, %int-1_10745 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_10746 = torch.constant.int 6 + %8708 = torch.prims.convert_element_type %8707, %int6_10746 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_10747 = torch.constant.int 1 + %8709 = torch.aten.unsqueeze %8679, %int1_10747 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_10748 = torch.constant.int 6 + %8710 = torch.prims.convert_element_type %8709, %int6_10748 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_10749 = 
torch.constant.int 0 + %8711 = torch.aten.unsqueeze %8708, %int0_10749 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_10750 = torch.constant.int 6 + %8712 = torch.prims.convert_element_type %8711, %int6_10750 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %8713 = torch.aten.mul.Tensor %8710, %8712 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %8714 = torch.aten.cos %8713 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_10751 = torch.constant.int 5 + %8715 = torch.prims.convert_element_type %8714, %int5_10751 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %8716 = torch.aten.sin %8713 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_10752 = torch.constant.int 5 + %8717 = torch.prims.convert_element_type %8716, %int5_10752 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_10753 = torch.constant.int 0 + %int0_10754 = torch.constant.int 0 + %int1_10755 = torch.constant.int 1 + %8718 = torch.aten.slice.Tensor %8715, %int0_10753, %int0_10754, %298, %int1_10755 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8718, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_10756 = torch.constant.int 1 + %int0_10757 = torch.constant.int 0 + %int9223372036854775807_10758 = torch.constant.int 9223372036854775807 + %int1_10759 = torch.constant.int 1 + %8719 = torch.aten.slice.Tensor %8718, %int1_10756, %int0_10757, %int9223372036854775807_10758, %int1_10759 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8719, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_10760 = torch.constant.int 0 + %int0_10761 = torch.constant.int 0 + %int1_10762 = torch.constant.int 1 + %8720 = torch.aten.slice.Tensor %8717, %int0_10760, %int0_10761, %298, %int1_10762 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8720, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_10763 = torch.constant.int 1 + %int0_10764 = torch.constant.int 0 + %int9223372036854775807_10765 = torch.constant.int 9223372036854775807 + %int1_10766 = torch.constant.int 1 + %8721 = torch.aten.slice.Tensor %8720, %int1_10763, %int0_10764, %int9223372036854775807_10765, %int1_10766 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8721, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_10767 = torch.constant.int 0 + %8722 = torch.aten.unsqueeze %8719, %int0_10767 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8722, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_10768 = torch.constant.int 1 + %int0_10769 = torch.constant.int 0 + %int9223372036854775807_10770 = torch.constant.int 9223372036854775807 + %int1_10771 = torch.constant.int 1 + %8723 = torch.aten.slice.Tensor %8722, %int1_10768, %int0_10769, %int9223372036854775807_10770, %int1_10771 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> 
!torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8723, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_10772 = torch.constant.int 2 + %8724 = torch.aten.unsqueeze %8723, %int2_10772 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8724, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_10773 = torch.constant.int 3 + %int0_10774 = torch.constant.int 0 + %int9223372036854775807_10775 = torch.constant.int 9223372036854775807 + %int1_10776 = torch.constant.int 1 + %8725 = torch.aten.slice.Tensor %8724, %int3_10773, %int0_10774, %int9223372036854775807_10775, %int1_10776 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8725, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_10777 = torch.constant.int 4 + %int1_10778 = torch.constant.int 1 + %int1_10779 = torch.constant.int 1 + %int1_10780 = torch.constant.int 1 + %8726 = torch.prim.ListConstruct %int4_10777, %int1_10778, %int1_10779, %int1_10780 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8727 = torch.aten.repeat %8725, %8726 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %8727, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_10781 = torch.constant.int 0 + %8728 = torch.aten.unsqueeze %8721, %int0_10781 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8728, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_10782 = torch.constant.int 1 + %int0_10783 = torch.constant.int 0 + %int9223372036854775807_10784 = torch.constant.int 9223372036854775807 + %int1_10785 = torch.constant.int 1 + %8729 = torch.aten.slice.Tensor %8728, %int1_10782, %int0_10783, %int9223372036854775807_10784, %int1_10785 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8729, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_10786 = torch.constant.int 2 + %8730 = torch.aten.unsqueeze %8729, %int2_10786 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8730, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_10787 = torch.constant.int 3 + %int0_10788 = torch.constant.int 0 + %int9223372036854775807_10789 = torch.constant.int 9223372036854775807 + %int1_10790 = torch.constant.int 1 + %8731 = torch.aten.slice.Tensor %8730, %int3_10787, %int0_10788, %int9223372036854775807_10789, %int1_10790 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8731, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_10791 = torch.constant.int 4 + %int1_10792 = torch.constant.int 1 + %int1_10793 = torch.constant.int 1 + %int1_10794 = torch.constant.int 1 + %8732 = torch.prim.ListConstruct %int4_10791, %int1_10792, %int1_10793, %int1_10794 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8733 = torch.aten.repeat %8731, %8732 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + 
torch.bind_symbolic_shape %8733, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %8734 = torch.aten.mul.Tensor %8674, %8727 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8734, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int3_10795 = torch.constant.int 3 + %int0_10796 = torch.constant.int 0 + %int64_10797 = torch.constant.int 64 + %int1_10798 = torch.constant.int 1 + %8735 = torch.aten.slice.Tensor %8674, %int3_10795, %int0_10796, %int64_10797, %int1_10798 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %8735, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %int3_10799 = torch.constant.int 3 + %int64_10800 = torch.constant.int 64 + %int9223372036854775807_10801 = torch.constant.int 9223372036854775807 + %int1_10802 = torch.constant.int 1 + %8736 = torch.aten.slice.Tensor %8674, %int3_10799, %int64_10800, %int9223372036854775807_10801, %int1_10802 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %8736, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %8737 = torch.aten.neg %8736 : !torch.vtensor<[4,?,32,64],f16> -> !torch.vtensor<[4,?,32,64],f16> + torch.bind_symbolic_shape %8737, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 64)> : !torch.vtensor<[4,?,32,64],f16> + %8738 = torch.prim.ListConstruct %8737, %8735 : (!torch.vtensor<[4,?,32,64],f16>, !torch.vtensor<[4,?,32,64],f16>) -> !torch.list + %int-1_10803 = torch.constant.int -1 + %8739 = torch.aten.cat %8738, %int-1_10803 : !torch.list, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8739, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %8740 = torch.aten.mul.Tensor %8739, %8733 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8740, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_10804 = torch.constant.int 1 + %8741 = torch.aten.add.Tensor %8734, %8740, %int1_10804 : !torch.vtensor<[4,?,32,128],f16>, !torch.vtensor<[4,?,32,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8741, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int131072_10805 = torch.constant.int 131072 + %none_10806 = torch.constant.none + %none_10807 = torch.constant.none + %cpu_10808 = torch.constant.device "cpu" + %false_10809 = torch.constant.bool false + %8742 = torch.aten.arange %int131072_10805, %none_10806, %none_10807, %cpu_10808, %false_10809 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %int0_10810 = torch.constant.int 0 + %int128_10811 = torch.constant.int 128 + %int2_10812 = torch.constant.int 2 + %int4_10813 = torch.constant.int 4 + %none_10814 = torch.constant.none + %cpu_10815 = torch.constant.device "cpu" + %false_10816 = torch.constant.bool false + %8743 = torch.aten.arange.start_step %int0_10810, %int128_10811, %int2_10812, %int4_10813, %none_10814, %cpu_10815, %false_10816 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> 
!torch.vtensor<[64],si64> + %int6_10817 = torch.constant.int 6 + %8744 = torch.prims.convert_element_type %8743, %int6_10817 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_10818 = torch.constant.int 128 + %8745 = torch.aten.div.Scalar %8744, %int128_10818 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float5.000000e05_10819 = torch.constant.float 5.000000e+05 + %8746 = torch.aten.pow.Scalar %float5.000000e05_10819, %8745 : !torch.float, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8747 = torch.aten.reciprocal %8746 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float1.000000e00_10820 = torch.constant.float 1.000000e+00 + %8748 = torch.aten.mul.Scalar %8747, %float1.000000e00_10820 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %8749 = torch.aten.reciprocal %8748 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %float6.283190e00_10821 = torch.constant.float 6.2831853071795862 + %8750 = torch.aten.mul.Scalar %8749, %float6.283190e00_10821 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32> + %float8.192000e03_10822 = torch.constant.float 8.192000e+03 + %8751 = torch.aten.gt.Scalar %8750, %float8.192000e03_10822 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %int8_10823 = torch.constant.int 8 + %8752 = torch.aten.div.Scalar %8748, %int8_10823 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %8753 = torch.aten.where.self %8751, %8752, %8748 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8754 = torch.aten.reciprocal %8750 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8192_10824 = torch.constant.int 8192 + %8755 = torch.aten.mul.Scalar %8754, %int8192_10824 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_10825 = torch.constant.int 1 + %int1_10826 = torch.constant.int 1 + %8756 = torch.aten.sub.Scalar %8755, %int1_10825, %int1_10826 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %int3_10827 = torch.constant.int 3 + %8757 = torch.aten.div.Scalar %8756, %int3_10827 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %int1_10828 = torch.constant.int 1 + %int1_10829 = torch.constant.int 1 + %8758 = torch.aten.rsub.Scalar %8757, %int1_10828, %int1_10829 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32> + %8759 = torch.aten.mul.Tensor %8758, %8753 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int8_10830 = torch.constant.int 8 + %8760 = torch.aten.div.Scalar %8759, %int8_10830 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %8761 = torch.aten.mul.Tensor %8757, %8753 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_10831 = torch.constant.int 1 + %8762 = torch.aten.add.Tensor %8760, %8761, %int1_10831 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03_10832 = torch.constant.float 2.048000e+03 + %8763 = torch.aten.lt.Scalar %8750, %float2.048000e03_10832 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %8764 = torch.aten.bitwise_not %8763 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_10833 = torch.constant.float 8.192000e+03 + %8765 = torch.aten.gt.Scalar %8750, %float8.192000e03_10833 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + 
%8766 = torch.aten.bitwise_not %8765 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %8767 = torch.aten.mul.Tensor %8764, %8766 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %8768 = torch.aten.where.self %8767, %8762, %8753 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %8769 = torch.prim.ListConstruct %8768, %8768 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_10834 = torch.constant.int -1 + %8770 = torch.aten.cat %8769, %int-1_10834 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> + %int6_10835 = torch.constant.int 6 + %8771 = torch.prims.convert_element_type %8770, %int6_10835 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_10836 = torch.constant.int 1 + %8772 = torch.aten.unsqueeze %8742, %int1_10836 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_10837 = torch.constant.int 6 + %8773 = torch.prims.convert_element_type %8772, %int6_10837 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_10838 = torch.constant.int 0 + %8774 = torch.aten.unsqueeze %8771, %int0_10838 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_10839 = torch.constant.int 6 + %8775 = torch.prims.convert_element_type %8774, %int6_10839 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %8776 = torch.aten.mul.Tensor %8773, %8775 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %8777 = torch.aten.cos %8776 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_10840 = torch.constant.int 5 + %8778 = torch.prims.convert_element_type %8777, %int5_10840 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %8779 = torch.aten.sin %8776 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> + %int5_10841 = torch.constant.int 5 + %8780 = torch.prims.convert_element_type %8779, %int5_10841 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int0_10842 = torch.constant.int 0 + %int0_10843 = torch.constant.int 0 + %int1_10844 = torch.constant.int 1 + %8781 = torch.aten.slice.Tensor %8778, %int0_10842, %int0_10843, %298, %int1_10844 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8781, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_10845 = torch.constant.int 1 + %int0_10846 = torch.constant.int 0 + %int9223372036854775807_10847 = torch.constant.int 9223372036854775807 + %int1_10848 = torch.constant.int 1 + %8782 = torch.aten.slice.Tensor %8781, %int1_10845, %int0_10846, %int9223372036854775807_10847, %int1_10848 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8782, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_10849 = torch.constant.int 0 + %int0_10850 = torch.constant.int 0 + %int1_10851 = torch.constant.int 1 + %8783 = torch.aten.slice.Tensor %8780, %int0_10849, %int0_10850, %298, %int1_10851 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8783, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int1_10852 = 
torch.constant.int 1 + %int0_10853 = torch.constant.int 0 + %int9223372036854775807_10854 = torch.constant.int 9223372036854775807 + %int1_10855 = torch.constant.int 1 + %8784 = torch.aten.slice.Tensor %8783, %int1_10852, %int0_10853, %int9223372036854775807_10854, %int1_10855 : !torch.vtensor<[?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %8784, [%294], affine_map<()[s0] -> (s0 * 32, 128)> : !torch.vtensor<[?,128],f16> + %int0_10856 = torch.constant.int 0 + %8785 = torch.aten.unsqueeze %8782, %int0_10856 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8785, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_10857 = torch.constant.int 1 + %int0_10858 = torch.constant.int 0 + %int9223372036854775807_10859 = torch.constant.int 9223372036854775807 + %int1_10860 = torch.constant.int 1 + %8786 = torch.aten.slice.Tensor %8785, %int1_10857, %int0_10858, %int9223372036854775807_10859, %int1_10860 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8786, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_10861 = torch.constant.int 2 + %8787 = torch.aten.unsqueeze %8786, %int2_10861 : !torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8787, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_10862 = torch.constant.int 3 + %int0_10863 = torch.constant.int 0 + %int9223372036854775807_10864 = torch.constant.int 9223372036854775807 + %int1_10865 = torch.constant.int 1 + %8788 = torch.aten.slice.Tensor %8787, %int3_10862, %int0_10863, %int9223372036854775807_10864, %int1_10865 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8788, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_10866 = torch.constant.int 4 + %int1_10867 = torch.constant.int 1 + %int1_10868 = torch.constant.int 1 + %int1_10869 = torch.constant.int 1 + %8789 = torch.prim.ListConstruct %int4_10866, %int1_10867, %int1_10868, %int1_10869 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8790 = torch.aten.repeat %8788, %8789 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %8790, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %int0_10870 = torch.constant.int 0 + %8791 = torch.aten.unsqueeze %8784, %int0_10870 : !torch.vtensor<[?,128],f16>, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8791, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int1_10871 = torch.constant.int 1 + %int0_10872 = torch.constant.int 0 + %int9223372036854775807_10873 = torch.constant.int 9223372036854775807 + %int1_10874 = torch.constant.int 1 + %8792 = torch.aten.slice.Tensor %8791, %int1_10871, %int0_10872, %int9223372036854775807_10873, %int1_10874 : !torch.vtensor<[1,?,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,128],f16> + torch.bind_symbolic_shape %8792, [%294], affine_map<()[s0] -> (1, s0 * 32, 128)> : !torch.vtensor<[1,?,128],f16> + %int2_10875 = torch.constant.int 2 + %8793 = torch.aten.unsqueeze %8792, %int2_10875 : 
!torch.vtensor<[1,?,128],f16>, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8793, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int3_10876 = torch.constant.int 3 + %int0_10877 = torch.constant.int 0 + %int9223372036854775807_10878 = torch.constant.int 9223372036854775807 + %int1_10879 = torch.constant.int 1 + %8794 = torch.aten.slice.Tensor %8793, %int3_10876, %int0_10877, %int9223372036854775807_10878, %int1_10879 : !torch.vtensor<[1,?,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?,1,128],f16> + torch.bind_symbolic_shape %8794, [%294], affine_map<()[s0] -> (1, s0 * 32, 1, 128)> : !torch.vtensor<[1,?,1,128],f16> + %int4_10880 = torch.constant.int 4 + %int1_10881 = torch.constant.int 1 + %int1_10882 = torch.constant.int 1 + %int1_10883 = torch.constant.int 1 + %8795 = torch.prim.ListConstruct %int4_10880, %int1_10881, %int1_10882, %int1_10883 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8796 = torch.aten.repeat %8794, %8795 : !torch.vtensor<[1,?,1,128],f16>, !torch.list -> !torch.vtensor<[4,?,1,128],f16> + torch.bind_symbolic_shape %8796, [%294], affine_map<()[s0] -> (4, s0 * 32, 1, 128)> : !torch.vtensor<[4,?,1,128],f16> + %8797 = torch.aten.mul.Tensor %8676, %8790 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8797, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int3_10884 = torch.constant.int 3 + %int0_10885 = torch.constant.int 0 + %int64_10886 = torch.constant.int 64 + %int1_10887 = torch.constant.int 1 + %8798 = torch.aten.slice.Tensor %8676, %int3_10884, %int0_10885, %int64_10886, %int1_10887 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %8798, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %int3_10888 = torch.constant.int 3 + %int64_10889 = torch.constant.int 64 + %int9223372036854775807_10890 = torch.constant.int 9223372036854775807 + %int1_10891 = torch.constant.int 1 + %8799 = torch.aten.slice.Tensor %8676, %int3_10888, %int64_10889, %int9223372036854775807_10890, %int1_10891 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %8799, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %8800 = torch.aten.neg %8799 : !torch.vtensor<[4,?,8,64],f16> -> !torch.vtensor<[4,?,8,64],f16> + torch.bind_symbolic_shape %8800, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 64)> : !torch.vtensor<[4,?,8,64],f16> + %8801 = torch.prim.ListConstruct %8800, %8798 : (!torch.vtensor<[4,?,8,64],f16>, !torch.vtensor<[4,?,8,64],f16>) -> !torch.list + %int-1_10892 = torch.constant.int -1 + %8802 = torch.aten.cat %8801, %int-1_10892 : !torch.list, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8802, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %8803 = torch.aten.mul.Tensor %8802, %8796 : !torch.vtensor<[4,?,8,128],f16>, !torch.vtensor<[4,?,1,128],f16> -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8803, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int1_10893 = torch.constant.int 1 + %8804 = torch.aten.add.Tensor %8797, %8803, %int1_10893 : !torch.vtensor<[4,?,8,128],f16>, 
!torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %8804, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int32_10894 = torch.constant.int 32 + %8805 = torch.aten.mul.Scalar %arg2, %int32_10894 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8805, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int31 = torch.constant.int 31 + %int1_10895 = torch.constant.int 1 + %8806 = torch.aten.add.Scalar %8805, %int31, %int1_10895 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8806, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_10896 = torch.constant.int 2 + %8807 = torch.aten.mul.Scalar %8806, %int2_10896 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8807, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int0_10897 = torch.constant.int 0 + %int1_10898 = torch.constant.int 1 + %8808 = torch.aten.add.Scalar %8807, %int0_10897, %int1_10898 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8808, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %8809 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %8810 = torch.aten.view %8808, %8809 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %8810, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_10899 = torch.constant.int 4 + %int32_10900 = torch.constant.int 32 + %int8_10901 = torch.constant.int 8 + %int128_10902 = torch.constant.int 128 + %8811 = torch.prim.ListConstruct %int4_10899, %296, %int32_10900, %int8_10901, %int128_10902 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8812 = torch.aten.view %8804, %8811 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %8812, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_10903 = torch.constant.int 32 + %int8_10904 = torch.constant.int 8 + %int128_10905 = torch.constant.int 128 + %8813 = torch.prim.ListConstruct %504, %int32_10903, %int8_10904, %int128_10905 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8814 = torch.aten.view %8812, %8813 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %8814, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_10906 = torch.constant.int 1 + %int2_10907 = torch.constant.int 2 + %8815 = torch.aten.transpose.int %8814, %int1_10906, %int2_10907 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8815, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_10908 = torch.constant.int 5 + %8816 = torch.prims.convert_element_type %8815, %int5_10908 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8816, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_10909 = torch.constant.int 32 + %int2_10910 = torch.constant.int 2 + %int8_10911 = torch.constant.int 8 + %int32_10912 = torch.constant.int 32 + %int128_10913 = 
torch.constant.int 128 + %8817 = torch.prim.ListConstruct %297, %int32_10909, %int2_10910, %int8_10911, %int32_10912, %int128_10913 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8818 = torch.aten.view %8580, %8817 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8818, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_10914 = torch.constant.int 8 + %int32_10915 = torch.constant.int 32 + %int128_10916 = torch.constant.int 128 + %8819 = torch.prim.ListConstruct %497, %int8_10914, %int32_10915, %int128_10916 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8820 = torch.aten.view %8818, %8819 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8820, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %8821 = torch.prim.ListConstruct %8810 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>> + %false_10917 = torch.constant.bool false + %8822 = torch.aten.index_put %8820, %8821, %8816, %false_10917 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8822, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_10918 = torch.constant.int 32 + %int2_10919 = torch.constant.int 2 + %int8_10920 = torch.constant.int 8 + %int32_10921 = torch.constant.int 32 + %int128_10922 = torch.constant.int 128 + %8823 = torch.prim.ListConstruct %297, %int32_10918, %int2_10919, %int8_10920, %int32_10921, %int128_10922 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8824 = torch.aten.view %8822, %8823 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8824, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_10923 = torch.constant.int 2097152 + %8825 = torch.prim.ListConstruct %297, %int2097152_10923 : (!torch.int, !torch.int) -> !torch.list + %8826 = torch.aten.view %8824, %8825 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %8826, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_10924 = torch.constant.int 32 + %int2_10925 = torch.constant.int 2 + %int8_10926 = torch.constant.int 8 + %int32_10927 = torch.constant.int 32 + %int128_10928 = torch.constant.int 128 + %8827 = torch.prim.ListConstruct %297, %int32_10924, %int2_10925, %int8_10926, %int32_10927, %int128_10928 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8828 = torch.aten.view %8826, %8827 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8828, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int8_10929 = torch.constant.int 8 + %int32_10930 = torch.constant.int 32 + %int128_10931 = torch.constant.int 128 + %8829 = torch.prim.ListConstruct %497, %int8_10929, %int32_10930, %int128_10931 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8830 = torch.aten.view %8828, %8829 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape 
%8830, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_10932 = torch.constant.int 32 + %8831 = torch.aten.mul.Scalar %arg2, %int32_10932 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8831, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int31_10933 = torch.constant.int 31 + %int1_10934 = torch.constant.int 1 + %8832 = torch.aten.add.Scalar %8831, %int31_10933, %int1_10934 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8832, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_10935 = torch.constant.int 2 + %8833 = torch.aten.mul.Scalar %8832, %int2_10935 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8833, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int1_10936 = torch.constant.int 1 + %int1_10937 = torch.constant.int 1 + %8834 = torch.aten.add.Scalar %8833, %int1_10936, %int1_10937 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %8834, [%294], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %8835 = torch.prim.ListConstruct %504 : (!torch.int) -> !torch.list + %8836 = torch.aten.view %8834, %8835 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %8836, [%294], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int4_10938 = torch.constant.int 4 + %int32_10939 = torch.constant.int 32 + %int8_10940 = torch.constant.int 8 + %int128_10941 = torch.constant.int 128 + %8837 = torch.prim.ListConstruct %int4_10938, %296, %int32_10939, %int8_10940, %int128_10941 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8838 = torch.aten.view %8678, %8837 : !torch.vtensor<[4,?,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %8838, [%294], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_10942 = torch.constant.int 32 + %int8_10943 = torch.constant.int 8 + %int128_10944 = torch.constant.int 128 + %8839 = torch.prim.ListConstruct %504, %int32_10942, %int8_10943, %int128_10944 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8840 = torch.aten.view %8838, %8839 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,8,128],f16> + torch.bind_symbolic_shape %8840, [%294], affine_map<()[s0] -> (s0 * 4, 32, 8, 128)> : !torch.vtensor<[?,32,8,128],f16> + %int1_10945 = torch.constant.int 1 + %int2_10946 = torch.constant.int 2 + %8841 = torch.aten.transpose.int %8840, %int1_10945, %int2_10946 : !torch.vtensor<[?,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8841, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int5_10947 = torch.constant.int 5 + %8842 = torch.prims.convert_element_type %8841, %int5_10947 : !torch.vtensor<[?,8,32,128],f16>, !torch.int -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8842, [%294], affine_map<()[s0] -> (s0 * 4, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %8843 = torch.prim.ListConstruct %8836 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>> + %false_10948 = torch.constant.bool false + %8844 = torch.aten.index_put %8830, %8843, %8842, %false_10948 : !torch.vtensor<[?,8,32,128],f16>, !torch.list<optional<vtensor>>, 
!torch.vtensor<[?,8,32,128],f16>, !torch.bool -> !torch.vtensor<[?,8,32,128],f16> + torch.bind_symbolic_shape %8844, [%295], affine_map<()[s0] -> (s0 * 64, 8, 32, 128)> : !torch.vtensor<[?,8,32,128],f16> + %int32_10949 = torch.constant.int 32 + %int2_10950 = torch.constant.int 2 + %int8_10951 = torch.constant.int 8 + %int32_10952 = torch.constant.int 32 + %int128_10953 = torch.constant.int 128 + %8845 = torch.prim.ListConstruct %297, %int32_10949, %int2_10950, %int8_10951, %int32_10952, %int128_10953 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8846 = torch.aten.view %8844, %8845 : !torch.vtensor<[?,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %8846, [%295], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_10954 = torch.constant.int 2097152 + %8847 = torch.prim.ListConstruct %297, %int2097152_10954 : (!torch.int, !torch.int) -> !torch.list + %8848 = torch.aten.view %8846, %8847 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.overwrite.tensor.contents %8848 overwrites %arg3 : !torch.vtensor<[?,2097152],f16>, !torch.tensor<[?,2097152],f16> + torch.bind_symbolic_shape %8848, [%295], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int-2_10955 = torch.constant.int -2 + %8849 = torch.aten.unsqueeze %8804, %int-2_10955 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %8849, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_10956 = torch.constant.int 4 + %int8_10957 = torch.constant.int 8 + %int4_10958 = torch.constant.int 4 + %int128_10959 = torch.constant.int 128 + %8850 = torch.prim.ListConstruct %int4_10956, %298, %int8_10957, %int4_10958, %int128_10959 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_10960 = torch.constant.bool false + %8851 = torch.aten.expand %8849, %8850, %false_10960 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8851, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_10961 = torch.constant.int 0 + %8852 = torch.aten.clone %8851, %int0_10961 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8852, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_10962 = torch.constant.int 4 + %int32_10963 = torch.constant.int 32 + %int128_10964 = torch.constant.int 128 + %8853 = torch.prim.ListConstruct %int4_10962, %298, %int32_10963, %int128_10964 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8854 = torch.aten._unsafe_view %8852, %8853 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8854, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_10965 = torch.constant.int -2 + %8855 = torch.aten.unsqueeze %8678, %int-2_10965 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %8855, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_10966 = torch.constant.int 4 + %int8_10967 = torch.constant.int 8 + %int4_10968 = torch.constant.int 4 + %int128_10969 = 
torch.constant.int 128 + %8856 = torch.prim.ListConstruct %int4_10966, %298, %int8_10967, %int4_10968, %int128_10969 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_10970 = torch.constant.bool false + %8857 = torch.aten.expand %8855, %8856, %false_10970 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8857, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_10971 = torch.constant.int 0 + %8858 = torch.aten.clone %8857, %int0_10971 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %8858, [%294], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_10972 = torch.constant.int 4 + %int32_10973 = torch.constant.int 32 + %int128_10974 = torch.constant.int 128 + %8859 = torch.prim.ListConstruct %int4_10972, %298, %int32_10973, %int128_10974 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %8860 = torch.aten._unsafe_view %8858, %8859 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8860, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_10975 = torch.constant.int 1 + %int2_10976 = torch.constant.int 2 + %8861 = torch.aten.transpose.int %8741, %int1_10975, %int2_10976 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %8861, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_10977 = torch.constant.int 1 + %int2_10978 = torch.constant.int 2 + %8862 = torch.aten.transpose.int %8854, %int1_10977, %int2_10978 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %8862, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_10979 = torch.constant.int 1 + %int2_10980 = torch.constant.int 2 + %8863 = torch.aten.transpose.int %8860, %int1_10979, %int2_10980 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %8863, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_10981 = torch.constant.float 0.000000e+00 + %false_10982 = torch.constant.bool false + %none_10983 = torch.constant.none + %8864:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%8861, %8862, %8863, %float0.000000e00_10981, %false_10982, %327, %none_10983) : (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,?,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?],f32>) + torch.bind_symbolic_shape %8864#0, [%294], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_10984 = torch.constant.int 1 + %int2_10985 = torch.constant.int 2 + %8865 = torch.aten.transpose.int %8864#0, %int1_10984, %int2_10985 : !torch.vtensor<[4,32,?,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %8865, [%294], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4_10986 = torch.constant.int 4 + %int4096_10987 = torch.constant.int 4096 + %8866 = torch.prim.ListConstruct 
%int4_10986, %298, %int4096_10987 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8867 = torch.aten.view %8865, %8866 : !torch.vtensor<[4,?,32,128],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8867, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_10988 = torch.constant.int -2 + %int-1_10989 = torch.constant.int -1 + %8868 = torch.aten.transpose.int %285, %int-2_10988, %int-1_10989 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_10990 = torch.constant.int 5 + %8869 = torch.prims.convert_element_type %8868, %int5_10990 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4096_10991 = torch.constant.int 4096 + %8870 = torch.prim.ListConstruct %342, %int4096_10991 : (!torch.int, !torch.int) -> !torch.list + %8871 = torch.aten.view %8867, %8870 : !torch.vtensor<[4,?,4096],f16>, !torch.list -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8871, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8872 = torch.aten.mm %8871, %8869 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8872, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_10992 = torch.constant.int 4 + %int4096_10993 = torch.constant.int 4096 + %8873 = torch.prim.ListConstruct %int4_10992, %298, %int4096_10993 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %8874 = torch.aten.view %8872, %8873 : !torch.vtensor<[?,4096],f16>, !torch.list -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8874, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_10994 = torch.constant.int 1 + %8875 = torch.aten.add.Tensor %8641, %8874, %int1_10994 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8875, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_10995 = torch.constant.int 6 + %8876 = torch.prims.convert_element_type %8875, %int6_10995 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8876, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_10996 = torch.constant.int 2 + %8877 = torch.aten.pow.Tensor_Scalar %8876, %int2_10996 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8877, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_10997 = torch.constant.int -1 + %8878 = torch.prim.ListConstruct %int-1_10997 : (!torch.int) -> !torch.list + %true_10998 = torch.constant.bool true + %none_10999 = torch.constant.none + %8879 = torch.aten.mean.dim %8877, %8878, %true_10998, %none_10999 : !torch.vtensor<[4,?,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8879, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_11000 = torch.constant.float 9.9999997473787516E-6 + %int1_11001 = torch.constant.int 1 + %8880 = torch.aten.add.Scalar %8879, %float9.999990e-06_11000, %int1_11001 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8880, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : 
!torch.vtensor<[4,?,1],f32> + %8881 = torch.aten.rsqrt %8880 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8881, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8882 = torch.aten.mul.Tensor %8876, %8881 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8882, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_11002 = torch.constant.int 5 + %8883 = torch.prims.convert_element_type %8882, %int5_11002 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8883, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %8884 = torch.aten.mul.Tensor %286, %8883 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8884, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_11003 = torch.constant.int 5 + %8885 = torch.prims.convert_element_type %8884, %int5_11003 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8885, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_11004 = torch.constant.int -2 + %int-1_11005 = torch.constant.int -1 + %8886 = torch.aten.transpose.int %287, %int-2_11004, %int-1_11005 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_11006 = torch.constant.int 5 + %8887 = torch.prims.convert_element_type %8886, %int5_11006 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_11007 = torch.constant.int 4096 + %8888 = torch.prim.ListConstruct %342, %int4096_11007 : (!torch.int, !torch.int) -> !torch.list<int> + %8889 = torch.aten.view %8885, %8888 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8889, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8890 = torch.aten.mm %8889, %8887 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %8890, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_11008 = torch.constant.int 4 + %int14336_11009 = torch.constant.int 14336 + %8891 = torch.prim.ListConstruct %int4_11008, %298, %int14336_11009 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> + %8892 = torch.aten.view %8890, %8891 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8892, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %8893 = torch.aten.silu %8892 : !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8893, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_11010 = torch.constant.int -2 + %int-1_11011 = torch.constant.int -1 + %8894 = torch.aten.transpose.int %288, %int-2_11010, %int-1_11011 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_11012 = torch.constant.int 5 + %8895 = torch.prims.convert_element_type %8894, %int5_11012 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_11013 = torch.constant.int 4096 + %8896 = torch.prim.ListConstruct %342,
%int4096_11013 : (!torch.int, !torch.int) -> !torch.list<int> + %8897 = torch.aten.view %8885, %8896 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8897, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8898 = torch.aten.mm %8897, %8895 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %8898, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %int4_11014 = torch.constant.int 4 + %int14336_11015 = torch.constant.int 14336 + %8899 = torch.prim.ListConstruct %int4_11014, %298, %int14336_11015 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> + %8900 = torch.aten.view %8898, %8899 : !torch.vtensor<[?,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8900, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %8901 = torch.aten.mul.Tensor %8893, %8900 : !torch.vtensor<[4,?,14336],f16>, !torch.vtensor<[4,?,14336],f16> -> !torch.vtensor<[4,?,14336],f16> + torch.bind_symbolic_shape %8901, [%294], affine_map<()[s0] -> (4, s0 * 32, 14336)> : !torch.vtensor<[4,?,14336],f16> + %int-2_11016 = torch.constant.int -2 + %int-1_11017 = torch.constant.int -1 + %8902 = torch.aten.transpose.int %289, %int-2_11016, %int-1_11017 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_11018 = torch.constant.int 5 + %8903 = torch.prims.convert_element_type %8902, %int5_11018 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int14336_11019 = torch.constant.int 14336 + %8904 = torch.prim.ListConstruct %342, %int14336_11019 : (!torch.int, !torch.int) -> !torch.list<int> + %8905 = torch.aten.view %8901, %8904 : !torch.vtensor<[4,?,14336],f16>, !torch.list<int> -> !torch.vtensor<[?,14336],f16> + torch.bind_symbolic_shape %8905, [%294], affine_map<()[s0] -> (s0 * 128, 14336)> : !torch.vtensor<[?,14336],f16> + %8906 = torch.aten.mm %8905, %8903 : !torch.vtensor<[?,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8906, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %int4_11020 = torch.constant.int 4 + %int4096_11021 = torch.constant.int 4096 + %8907 = torch.prim.ListConstruct %int4_11020, %298, %int4096_11021 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> + %8908 = torch.aten.view %8906, %8907 : !torch.vtensor<[?,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8908, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int1_11022 = torch.constant.int 1 + %8909 = torch.aten.add.Tensor %8875, %8908, %int1_11022 : !torch.vtensor<[4,?,4096],f16>, !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8909, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int6_11023 = torch.constant.int 6 + %8910 = torch.prims.convert_element_type %8909, %int6_11023 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8910, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int2_11024 = torch.constant.int 2 + %8911 = torch.aten.pow.Tensor_Scalar %8910, %int2_11024 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape
%8911, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int-1_11025 = torch.constant.int -1 + %8912 = torch.prim.ListConstruct %int-1_11025 : (!torch.int) -> !torch.list<int> + %true_11026 = torch.constant.bool true + %none_11027 = torch.constant.none + %8913 = torch.aten.mean.dim %8911, %8912, %true_11026, %none_11027 : !torch.vtensor<[4,?,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8913, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %float9.999990e-06_11028 = torch.constant.float 9.9999997473787516E-6 + %int1_11029 = torch.constant.int 1 + %8914 = torch.aten.add.Scalar %8913, %float9.999990e-06_11028, %int1_11029 : !torch.vtensor<[4,?,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8914, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8915 = torch.aten.rsqrt %8914 : !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,1],f32> + torch.bind_symbolic_shape %8915, [%294], affine_map<()[s0] -> (4, s0 * 32, 1)> : !torch.vtensor<[4,?,1],f32> + %8916 = torch.aten.mul.Tensor %8910, %8915 : !torch.vtensor<[4,?,4096],f32>, !torch.vtensor<[4,?,1],f32> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8916, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_11030 = torch.constant.int 5 + %8917 = torch.prims.convert_element_type %8916, %int5_11030 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8917, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %8918 = torch.aten.mul.Tensor %290, %8917 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,?,4096],f16> -> !torch.vtensor<[4,?,4096],f32> + torch.bind_symbolic_shape %8918, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f32> + %int5_11031 = torch.constant.int 5 + %8919 = torch.prims.convert_element_type %8918, %int5_11031 : !torch.vtensor<[4,?,4096],f32>, !torch.int -> !torch.vtensor<[4,?,4096],f16> + torch.bind_symbolic_shape %8919, [%294], affine_map<()[s0] -> (4, s0 * 32, 4096)> : !torch.vtensor<[4,?,4096],f16> + %int-2_11032 = torch.constant.int -2 + %int-1_11033 = torch.constant.int -1 + %8920 = torch.aten.transpose.int %291, %int-2_11032, %int-1_11033 : !torch.vtensor<[128256,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,128256],f16> + %int5_11034 = torch.constant.int 5 + %8921 = torch.prims.convert_element_type %8920, %int5_11034 : !torch.vtensor<[4096,128256],f16>, !torch.int -> !torch.vtensor<[4096,128256],f16> + %int4096_11035 = torch.constant.int 4096 + %8922 = torch.prim.ListConstruct %342, %int4096_11035 : (!torch.int, !torch.int) -> !torch.list<int> + %8923 = torch.aten.view %8919, %8922 : !torch.vtensor<[4,?,4096],f16>, !torch.list<int> -> !torch.vtensor<[?,4096],f16> + torch.bind_symbolic_shape %8923, [%294], affine_map<()[s0] -> (s0 * 128, 4096)> : !torch.vtensor<[?,4096],f16> + %8924 = torch.aten.mm %8923, %8921 : !torch.vtensor<[?,4096],f16>, !torch.vtensor<[4096,128256],f16> -> !torch.vtensor<[?,128256],f16> + torch.bind_symbolic_shape %8924, [%294], affine_map<()[s0] -> (s0 * 128, 128256)> : !torch.vtensor<[?,128256],f16> + %int4_11036 = torch.constant.int 4 %int128256 = torch.constant.int 128256 - %6617 = torch.prim.ListConstruct %int4_8094, %306, %int128256 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> - %6618 = torch.aten.view %6616, %6617 :
!torch.vtensor<[?,128256],f16>, !torch.list<int> -> !torch.vtensor<[4,?,128256],f16> - torch.bind_symbolic_shape %6618, [%292], affine_map<()[s0] -> (4, s0 * 32, 128256)> : !torch.vtensor<[4,?,128256],f16> - return %6618 : !torch.vtensor<[4,?,128256],f16> + %8925 = torch.prim.ListConstruct %int4_11036, %298, %int128256 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> + %8926 = torch.aten.view %8924, %8925 : !torch.vtensor<[?,128256],f16>, !torch.list<int> -> !torch.vtensor<[4,?,128256],f16> + torch.bind_symbolic_shape %8926, [%294], affine_map<()[s0] -> (4, s0 * 32, 128256)> : !torch.vtensor<[4,?,128256],f16> + return %8926 : !torch.vtensor<[4,?,128256],f16> } func.func @decode_bs4(%arg0: !torch.vtensor<[4,1],si64>, %arg1: !torch.vtensor<[4],si64>, %arg2: !torch.vtensor<[4],si64>, %arg3: !torch.vtensor<[4,?],si64>, %arg4: !torch.tensor<[?,2097152],f16>) -> !torch.vtensor<[4,1,128256],f16> attributes {torch.assume_strict_symbolic_shapes} { + %0 = torch.vtensor.literal(dense<0xFC00> : tensor<f16>) : !torch.vtensor<[],f16> %__auto.token_embd.weight = util.global.load @__auto.token_embd.weight : tensor<128256x4096xf16> - %0 = torch_c.from_builtin_tensor %__auto.token_embd.weight : tensor<128256x4096xf16> -> !torch.vtensor<[128256,4096],f16> + %1 = torch_c.from_builtin_tensor %__auto.token_embd.weight : tensor<128256x4096xf16> -> !torch.vtensor<[128256,4096],f16> %__auto.blk.0.attn_norm.weight = util.global.load @__auto.blk.0.attn_norm.weight : tensor<4096xf32> - %1 = torch_c.from_builtin_tensor %__auto.blk.0.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %2 = torch_c.from_builtin_tensor %__auto.blk.0.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.0.attn_q.weight = util.global.load @__auto.blk.0.attn_q.weight : tensor<4096x4096xf16> - %2 = torch_c.from_builtin_tensor %__auto.blk.0.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %3 = torch_c.from_builtin_tensor %__auto.blk.0.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.0.attn_k.weight = util.global.load @__auto.blk.0.attn_k.weight : tensor<1024x4096xf16> - %3 = torch_c.from_builtin_tensor %__auto.blk.0.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %4 = torch_c.from_builtin_tensor %__auto.blk.0.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.0.attn_v.weight = util.global.load @__auto.blk.0.attn_v.weight : tensor<1024x4096xf16> - %4 = torch_c.from_builtin_tensor %__auto.blk.0.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %5 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %6 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %5 = torch_c.from_builtin_tensor %__auto.blk.0.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %6 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %7 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %8 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %9 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %10 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> %__auto.blk.0.attn_output.weight = util.global.load @__auto.blk.0.attn_output.weight : tensor<4096x4096xf16> - %7 = torch_c.from_builtin_tensor %__auto.blk.0.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %11 = torch_c.from_builtin_tensor %__auto.blk.0.attn_output.weight
: tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.0.ffn_norm.weight = util.global.load @__auto.blk.0.ffn_norm.weight : tensor<4096xf32> - %8 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %12 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.0.ffn_gate.weight = util.global.load @__auto.blk.0.ffn_gate.weight : tensor<14336x4096xf16> - %9 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %13 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.0.ffn_up.weight = util.global.load @__auto.blk.0.ffn_up.weight : tensor<14336x4096xf16> - %10 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %14 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.0.ffn_down.weight = util.global.load @__auto.blk.0.ffn_down.weight : tensor<4096x14336xf16> - %11 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %15 = torch_c.from_builtin_tensor %__auto.blk.0.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.1.attn_norm.weight = util.global.load @__auto.blk.1.attn_norm.weight : tensor<4096xf32> - %12 = torch_c.from_builtin_tensor %__auto.blk.1.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %16 = torch_c.from_builtin_tensor %__auto.blk.1.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.1.attn_q.weight = util.global.load @__auto.blk.1.attn_q.weight : tensor<4096x4096xf16> - %13 = torch_c.from_builtin_tensor %__auto.blk.1.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %17 = torch_c.from_builtin_tensor %__auto.blk.1.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.1.attn_k.weight = util.global.load @__auto.blk.1.attn_k.weight : tensor<1024x4096xf16> - %14 = torch_c.from_builtin_tensor %__auto.blk.1.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %18 = torch_c.from_builtin_tensor %__auto.blk.1.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.1.attn_v.weight = util.global.load @__auto.blk.1.attn_v.weight : tensor<1024x4096xf16> - %15 = torch_c.from_builtin_tensor %__auto.blk.1.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %16 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %17 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %19 = torch_c.from_builtin_tensor %__auto.blk.1.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %20 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %21 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %22 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %23 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %24 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> %__auto.blk.1.attn_output.weight = util.global.load @__auto.blk.1.attn_output.weight : tensor<4096x4096xf16> - %18 = torch_c.from_builtin_tensor %__auto.blk.1.attn_output.weight : tensor<4096x4096xf16> ->
!torch.vtensor<[4096,4096],f16> + %25 = torch_c.from_builtin_tensor %__auto.blk.1.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.1.ffn_norm.weight = util.global.load @__auto.blk.1.ffn_norm.weight : tensor<4096xf32> - %19 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %26 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.1.ffn_gate.weight = util.global.load @__auto.blk.1.ffn_gate.weight : tensor<14336x4096xf16> - %20 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %27 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.1.ffn_up.weight = util.global.load @__auto.blk.1.ffn_up.weight : tensor<14336x4096xf16> - %21 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %28 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.1.ffn_down.weight = util.global.load @__auto.blk.1.ffn_down.weight : tensor<4096x14336xf16> - %22 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %29 = torch_c.from_builtin_tensor %__auto.blk.1.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.2.attn_norm.weight = util.global.load @__auto.blk.2.attn_norm.weight : tensor<4096xf32> - %23 = torch_c.from_builtin_tensor %__auto.blk.2.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %30 = torch_c.from_builtin_tensor %__auto.blk.2.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.2.attn_q.weight = util.global.load @__auto.blk.2.attn_q.weight : tensor<4096x4096xf16> - %24 = torch_c.from_builtin_tensor %__auto.blk.2.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %31 = torch_c.from_builtin_tensor %__auto.blk.2.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.2.attn_k.weight = util.global.load @__auto.blk.2.attn_k.weight : tensor<1024x4096xf16> - %25 = torch_c.from_builtin_tensor %__auto.blk.2.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %32 = torch_c.from_builtin_tensor %__auto.blk.2.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.2.attn_v.weight = util.global.load @__auto.blk.2.attn_v.weight : tensor<1024x4096xf16> - %26 = torch_c.from_builtin_tensor %__auto.blk.2.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %27 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %28 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %33 = torch_c.from_builtin_tensor %__auto.blk.2.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %34 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %35 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %36 = torch.vtensor.literal(dense<2> : tensor<si64>) : !torch.vtensor<[],si64> + %37 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %38 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> %__auto.blk.2.attn_output.weight = util.global.load @__auto.blk.2.attn_output.weight : tensor<4096x4096xf16>
- %29 = torch_c.from_builtin_tensor %__auto.blk.2.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %39 = torch_c.from_builtin_tensor %__auto.blk.2.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.2.ffn_norm.weight = util.global.load @__auto.blk.2.ffn_norm.weight : tensor<4096xf32> - %30 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %40 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.2.ffn_gate.weight = util.global.load @__auto.blk.2.ffn_gate.weight : tensor<14336x4096xf16> - %31 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %41 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.2.ffn_up.weight = util.global.load @__auto.blk.2.ffn_up.weight : tensor<14336x4096xf16> - %32 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %42 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.2.ffn_down.weight = util.global.load @__auto.blk.2.ffn_down.weight : tensor<4096x14336xf16> - %33 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %43 = torch_c.from_builtin_tensor %__auto.blk.2.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.3.attn_norm.weight = util.global.load @__auto.blk.3.attn_norm.weight : tensor<4096xf32> - %34 = torch_c.from_builtin_tensor %__auto.blk.3.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %44 = torch_c.from_builtin_tensor %__auto.blk.3.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.3.attn_q.weight = util.global.load @__auto.blk.3.attn_q.weight : tensor<4096x4096xf16> - %35 = torch_c.from_builtin_tensor %__auto.blk.3.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %45 = torch_c.from_builtin_tensor %__auto.blk.3.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.3.attn_k.weight = util.global.load @__auto.blk.3.attn_k.weight : tensor<1024x4096xf16> - %36 = torch_c.from_builtin_tensor %__auto.blk.3.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %46 = torch_c.from_builtin_tensor %__auto.blk.3.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.3.attn_v.weight = util.global.load @__auto.blk.3.attn_v.weight : tensor<1024x4096xf16> - %37 = torch_c.from_builtin_tensor %__auto.blk.3.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %38 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %39 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %47 = torch_c.from_builtin_tensor %__auto.blk.3.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %48 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %49 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %50 = torch.vtensor.literal(dense<3> : tensor<si64>) : !torch.vtensor<[],si64> + %51 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %52 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
%__auto.blk.3.attn_output.weight = util.global.load @__auto.blk.3.attn_output.weight : tensor<4096x4096xf16> - %40 = torch_c.from_builtin_tensor %__auto.blk.3.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %53 = torch_c.from_builtin_tensor %__auto.blk.3.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.3.ffn_norm.weight = util.global.load @__auto.blk.3.ffn_norm.weight : tensor<4096xf32> - %41 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %54 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.3.ffn_gate.weight = util.global.load @__auto.blk.3.ffn_gate.weight : tensor<14336x4096xf16> - %42 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %55 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.3.ffn_up.weight = util.global.load @__auto.blk.3.ffn_up.weight : tensor<14336x4096xf16> - %43 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %56 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.3.ffn_down.weight = util.global.load @__auto.blk.3.ffn_down.weight : tensor<4096x14336xf16> - %44 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %57 = torch_c.from_builtin_tensor %__auto.blk.3.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.4.attn_norm.weight = util.global.load @__auto.blk.4.attn_norm.weight : tensor<4096xf32> - %45 = torch_c.from_builtin_tensor %__auto.blk.4.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %58 = torch_c.from_builtin_tensor %__auto.blk.4.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.4.attn_q.weight = util.global.load @__auto.blk.4.attn_q.weight : tensor<4096x4096xf16> - %46 = torch_c.from_builtin_tensor %__auto.blk.4.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %59 = torch_c.from_builtin_tensor %__auto.blk.4.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.4.attn_k.weight = util.global.load @__auto.blk.4.attn_k.weight : tensor<1024x4096xf16> - %47 = torch_c.from_builtin_tensor %__auto.blk.4.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %60 = torch_c.from_builtin_tensor %__auto.blk.4.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.4.attn_v.weight = util.global.load @__auto.blk.4.attn_v.weight : tensor<1024x4096xf16> - %48 = torch_c.from_builtin_tensor %__auto.blk.4.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %49 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %50 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %61 = torch_c.from_builtin_tensor %__auto.blk.4.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %62 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %63 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %64 = torch.vtensor.literal(dense<4> : tensor<si64>) : !torch.vtensor<[],si64> + %65 = torch.vtensor.literal(dense<0> : tensor<si64>) :
!torch.vtensor<[],si64> + %66 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> %__auto.blk.4.attn_output.weight = util.global.load @__auto.blk.4.attn_output.weight : tensor<4096x4096xf16> - %51 = torch_c.from_builtin_tensor %__auto.blk.4.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %67 = torch_c.from_builtin_tensor %__auto.blk.4.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.4.ffn_norm.weight = util.global.load @__auto.blk.4.ffn_norm.weight : tensor<4096xf32> - %52 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %68 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.4.ffn_gate.weight = util.global.load @__auto.blk.4.ffn_gate.weight : tensor<14336x4096xf16> - %53 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %69 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.4.ffn_up.weight = util.global.load @__auto.blk.4.ffn_up.weight : tensor<14336x4096xf16> - %54 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %70 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.4.ffn_down.weight = util.global.load @__auto.blk.4.ffn_down.weight : tensor<4096x14336xf16> - %55 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %71 = torch_c.from_builtin_tensor %__auto.blk.4.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.5.attn_norm.weight = util.global.load @__auto.blk.5.attn_norm.weight : tensor<4096xf32> - %56 = torch_c.from_builtin_tensor %__auto.blk.5.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %72 = torch_c.from_builtin_tensor %__auto.blk.5.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.5.attn_q.weight = util.global.load @__auto.blk.5.attn_q.weight : tensor<4096x4096xf16> - %57 = torch_c.from_builtin_tensor %__auto.blk.5.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %73 = torch_c.from_builtin_tensor %__auto.blk.5.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.5.attn_k.weight = util.global.load @__auto.blk.5.attn_k.weight : tensor<1024x4096xf16> - %58 = torch_c.from_builtin_tensor %__auto.blk.5.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %74 = torch_c.from_builtin_tensor %__auto.blk.5.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.5.attn_v.weight = util.global.load @__auto.blk.5.attn_v.weight : tensor<1024x4096xf16> - %59 = torch_c.from_builtin_tensor %__auto.blk.5.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %60 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %61 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %75 = torch_c.from_builtin_tensor %__auto.blk.5.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %76 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %77 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %78 =
torch.vtensor.literal(dense<5> : tensor<si64>) : !torch.vtensor<[],si64> + %79 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %80 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> %__auto.blk.5.attn_output.weight = util.global.load @__auto.blk.5.attn_output.weight : tensor<4096x4096xf16> - %62 = torch_c.from_builtin_tensor %__auto.blk.5.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %81 = torch_c.from_builtin_tensor %__auto.blk.5.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.5.ffn_norm.weight = util.global.load @__auto.blk.5.ffn_norm.weight : tensor<4096xf32> - %63 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %82 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.5.ffn_gate.weight = util.global.load @__auto.blk.5.ffn_gate.weight : tensor<14336x4096xf16> - %64 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %83 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.5.ffn_up.weight = util.global.load @__auto.blk.5.ffn_up.weight : tensor<14336x4096xf16> - %65 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %84 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.5.ffn_down.weight = util.global.load @__auto.blk.5.ffn_down.weight : tensor<4096x14336xf16> - %66 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %85 = torch_c.from_builtin_tensor %__auto.blk.5.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.6.attn_norm.weight = util.global.load @__auto.blk.6.attn_norm.weight : tensor<4096xf32> - %67 = torch_c.from_builtin_tensor %__auto.blk.6.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %86 = torch_c.from_builtin_tensor %__auto.blk.6.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.6.attn_q.weight = util.global.load @__auto.blk.6.attn_q.weight : tensor<4096x4096xf16> - %68 = torch_c.from_builtin_tensor %__auto.blk.6.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %87 = torch_c.from_builtin_tensor %__auto.blk.6.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.6.attn_k.weight = util.global.load @__auto.blk.6.attn_k.weight : tensor<1024x4096xf16> - %69 = torch_c.from_builtin_tensor %__auto.blk.6.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %88 = torch_c.from_builtin_tensor %__auto.blk.6.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.6.attn_v.weight = util.global.load @__auto.blk.6.attn_v.weight : tensor<1024x4096xf16> - %70 = torch_c.from_builtin_tensor %__auto.blk.6.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %71 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %72 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %89 = torch_c.from_builtin_tensor %__auto.blk.6.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %90 = torch.vtensor.literal(dense<0> : tensor<si64>) :
!torch.vtensor<[],si64> + %91 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %92 = torch.vtensor.literal(dense<6> : tensor<si64>) : !torch.vtensor<[],si64> + %93 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %94 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> %__auto.blk.6.attn_output.weight = util.global.load @__auto.blk.6.attn_output.weight : tensor<4096x4096xf16> - %73 = torch_c.from_builtin_tensor %__auto.blk.6.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %95 = torch_c.from_builtin_tensor %__auto.blk.6.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.6.ffn_norm.weight = util.global.load @__auto.blk.6.ffn_norm.weight : tensor<4096xf32> - %74 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %96 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.6.ffn_gate.weight = util.global.load @__auto.blk.6.ffn_gate.weight : tensor<14336x4096xf16> - %75 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %97 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.6.ffn_up.weight = util.global.load @__auto.blk.6.ffn_up.weight : tensor<14336x4096xf16> - %76 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %98 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.6.ffn_down.weight = util.global.load @__auto.blk.6.ffn_down.weight : tensor<4096x14336xf16> - %77 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %99 = torch_c.from_builtin_tensor %__auto.blk.6.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.7.attn_norm.weight = util.global.load @__auto.blk.7.attn_norm.weight : tensor<4096xf32> - %78 = torch_c.from_builtin_tensor %__auto.blk.7.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %100 = torch_c.from_builtin_tensor %__auto.blk.7.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.7.attn_q.weight = util.global.load @__auto.blk.7.attn_q.weight : tensor<4096x4096xf16> - %79 = torch_c.from_builtin_tensor %__auto.blk.7.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %101 = torch_c.from_builtin_tensor %__auto.blk.7.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.7.attn_k.weight = util.global.load @__auto.blk.7.attn_k.weight : tensor<1024x4096xf16> - %80 = torch_c.from_builtin_tensor %__auto.blk.7.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %102 = torch_c.from_builtin_tensor %__auto.blk.7.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.7.attn_v.weight = util.global.load @__auto.blk.7.attn_v.weight : tensor<1024x4096xf16> - %81 = torch_c.from_builtin_tensor %__auto.blk.7.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %82 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %83 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %103 = torch_c.from_builtin_tensor %__auto.blk.7.attn_v.weight :
tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %104 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %105 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %106 = torch.vtensor.literal(dense<7> : tensor<si64>) : !torch.vtensor<[],si64> + %107 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %108 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> %__auto.blk.7.attn_output.weight = util.global.load @__auto.blk.7.attn_output.weight : tensor<4096x4096xf16> - %84 = torch_c.from_builtin_tensor %__auto.blk.7.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %109 = torch_c.from_builtin_tensor %__auto.blk.7.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.7.ffn_norm.weight = util.global.load @__auto.blk.7.ffn_norm.weight : tensor<4096xf32> - %85 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %110 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.7.ffn_gate.weight = util.global.load @__auto.blk.7.ffn_gate.weight : tensor<14336x4096xf16> - %86 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %111 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.7.ffn_up.weight = util.global.load @__auto.blk.7.ffn_up.weight : tensor<14336x4096xf16> - %87 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %112 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.7.ffn_down.weight = util.global.load @__auto.blk.7.ffn_down.weight : tensor<4096x14336xf16> - %88 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %113 = torch_c.from_builtin_tensor %__auto.blk.7.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.8.attn_norm.weight = util.global.load @__auto.blk.8.attn_norm.weight : tensor<4096xf32> - %89 = torch_c.from_builtin_tensor %__auto.blk.8.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %114 = torch_c.from_builtin_tensor %__auto.blk.8.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.8.attn_q.weight = util.global.load @__auto.blk.8.attn_q.weight : tensor<4096x4096xf16> - %90 = torch_c.from_builtin_tensor %__auto.blk.8.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %115 = torch_c.from_builtin_tensor %__auto.blk.8.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.8.attn_k.weight = util.global.load @__auto.blk.8.attn_k.weight : tensor<1024x4096xf16> - %91 = torch_c.from_builtin_tensor %__auto.blk.8.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %116 = torch_c.from_builtin_tensor %__auto.blk.8.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.8.attn_v.weight = util.global.load @__auto.blk.8.attn_v.weight : tensor<1024x4096xf16> - %92 = torch_c.from_builtin_tensor %__auto.blk.8.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %93 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %94 =
torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %117 = torch_c.from_builtin_tensor %__auto.blk.8.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %118 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %119 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %120 = torch.vtensor.literal(dense<8> : tensor<si64>) : !torch.vtensor<[],si64> + %121 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %122 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> %__auto.blk.8.attn_output.weight = util.global.load @__auto.blk.8.attn_output.weight : tensor<4096x4096xf16> - %95 = torch_c.from_builtin_tensor %__auto.blk.8.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %123 = torch_c.from_builtin_tensor %__auto.blk.8.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.8.ffn_norm.weight = util.global.load @__auto.blk.8.ffn_norm.weight : tensor<4096xf32> - %96 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %124 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.8.ffn_gate.weight = util.global.load @__auto.blk.8.ffn_gate.weight : tensor<14336x4096xf16> - %97 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %125 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.8.ffn_up.weight = util.global.load @__auto.blk.8.ffn_up.weight : tensor<14336x4096xf16> - %98 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %126 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.8.ffn_down.weight = util.global.load @__auto.blk.8.ffn_down.weight : tensor<4096x14336xf16> - %99 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %127 = torch_c.from_builtin_tensor %__auto.blk.8.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.9.attn_norm.weight = util.global.load @__auto.blk.9.attn_norm.weight : tensor<4096xf32> - %100 = torch_c.from_builtin_tensor %__auto.blk.9.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %128 = torch_c.from_builtin_tensor %__auto.blk.9.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.9.attn_q.weight = util.global.load @__auto.blk.9.attn_q.weight : tensor<4096x4096xf16> - %101 = torch_c.from_builtin_tensor %__auto.blk.9.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %129 = torch_c.from_builtin_tensor %__auto.blk.9.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.9.attn_k.weight = util.global.load @__auto.blk.9.attn_k.weight : tensor<1024x4096xf16> - %102 = torch_c.from_builtin_tensor %__auto.blk.9.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %130 = torch_c.from_builtin_tensor %__auto.blk.9.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.9.attn_v.weight = util.global.load @__auto.blk.9.attn_v.weight : tensor<1024x4096xf16> - %103 = torch_c.from_builtin_tensor %__auto.blk.9.attn_v.weight : tensor<1024x4096xf16> ->
!torch.vtensor<[1024,4096],f16> - %104 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %105 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %131 = torch_c.from_builtin_tensor %__auto.blk.9.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %132 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %133 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %134 = torch.vtensor.literal(dense<9> : tensor<si64>) : !torch.vtensor<[],si64> + %135 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %136 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> %__auto.blk.9.attn_output.weight = util.global.load @__auto.blk.9.attn_output.weight : tensor<4096x4096xf16> - %106 = torch_c.from_builtin_tensor %__auto.blk.9.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %137 = torch_c.from_builtin_tensor %__auto.blk.9.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.9.ffn_norm.weight = util.global.load @__auto.blk.9.ffn_norm.weight : tensor<4096xf32> - %107 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %138 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.9.ffn_gate.weight = util.global.load @__auto.blk.9.ffn_gate.weight : tensor<14336x4096xf16> - %108 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %139 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.9.ffn_up.weight = util.global.load @__auto.blk.9.ffn_up.weight : tensor<14336x4096xf16> - %109 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %140 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.9.ffn_down.weight = util.global.load @__auto.blk.9.ffn_down.weight : tensor<4096x14336xf16> - %110 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %141 = torch_c.from_builtin_tensor %__auto.blk.9.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.10.attn_norm.weight = util.global.load @__auto.blk.10.attn_norm.weight : tensor<4096xf32> - %111 = torch_c.from_builtin_tensor %__auto.blk.10.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %142 = torch_c.from_builtin_tensor %__auto.blk.10.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.10.attn_q.weight = util.global.load @__auto.blk.10.attn_q.weight : tensor<4096x4096xf16> - %112 = torch_c.from_builtin_tensor %__auto.blk.10.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %143 = torch_c.from_builtin_tensor %__auto.blk.10.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.10.attn_k.weight = util.global.load @__auto.blk.10.attn_k.weight : tensor<1024x4096xf16> - %113 = torch_c.from_builtin_tensor %__auto.blk.10.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %144 = torch_c.from_builtin_tensor %__auto.blk.10.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.10.attn_v.weight = util.global.load
@__auto.blk.10.attn_v.weight : tensor<1024x4096xf16> - %114 = torch_c.from_builtin_tensor %__auto.blk.10.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %115 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %116 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %145 = torch_c.from_builtin_tensor %__auto.blk.10.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %146 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %147 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %148 = torch.vtensor.literal(dense<10> : tensor<si64>) : !torch.vtensor<[],si64> + %149 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %150 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> %__auto.blk.10.attn_output.weight = util.global.load @__auto.blk.10.attn_output.weight : tensor<4096x4096xf16> - %117 = torch_c.from_builtin_tensor %__auto.blk.10.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %151 = torch_c.from_builtin_tensor %__auto.blk.10.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.10.ffn_norm.weight = util.global.load @__auto.blk.10.ffn_norm.weight : tensor<4096xf32> - %118 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %152 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.10.ffn_gate.weight = util.global.load @__auto.blk.10.ffn_gate.weight : tensor<14336x4096xf16> - %119 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %153 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.10.ffn_up.weight = util.global.load @__auto.blk.10.ffn_up.weight : tensor<14336x4096xf16> - %120 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %154 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.10.ffn_down.weight = util.global.load @__auto.blk.10.ffn_down.weight : tensor<4096x14336xf16> - %121 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %155 = torch_c.from_builtin_tensor %__auto.blk.10.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.11.attn_norm.weight = util.global.load @__auto.blk.11.attn_norm.weight : tensor<4096xf32> - %122 = torch_c.from_builtin_tensor %__auto.blk.11.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %156 = torch_c.from_builtin_tensor %__auto.blk.11.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.11.attn_q.weight = util.global.load @__auto.blk.11.attn_q.weight : tensor<4096x4096xf16> - %123 = torch_c.from_builtin_tensor %__auto.blk.11.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %157 = torch_c.from_builtin_tensor %__auto.blk.11.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.11.attn_k.weight = util.global.load @__auto.blk.11.attn_k.weight : tensor<1024x4096xf16> - %124 = torch_c.from_builtin_tensor %__auto.blk.11.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %158 =
torch_c.from_builtin_tensor %__auto.blk.11.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.11.attn_v.weight = util.global.load @__auto.blk.11.attn_v.weight : tensor<1024x4096xf16> - %125 = torch_c.from_builtin_tensor %__auto.blk.11.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %126 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %127 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %159 = torch_c.from_builtin_tensor %__auto.blk.11.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %160 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %161 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %162 = torch.vtensor.literal(dense<11> : tensor<si64>) : !torch.vtensor<[],si64> + %163 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %164 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> %__auto.blk.11.attn_output.weight = util.global.load @__auto.blk.11.attn_output.weight : tensor<4096x4096xf16> - %128 = torch_c.from_builtin_tensor %__auto.blk.11.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %165 = torch_c.from_builtin_tensor %__auto.blk.11.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.11.ffn_norm.weight = util.global.load @__auto.blk.11.ffn_norm.weight : tensor<4096xf32> - %129 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %166 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.11.ffn_gate.weight = util.global.load @__auto.blk.11.ffn_gate.weight : tensor<14336x4096xf16> - %130 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %167 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.11.ffn_up.weight = util.global.load @__auto.blk.11.ffn_up.weight : tensor<14336x4096xf16> - %131 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %168 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.11.ffn_down.weight = util.global.load @__auto.blk.11.ffn_down.weight : tensor<4096x14336xf16> - %132 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %169 = torch_c.from_builtin_tensor %__auto.blk.11.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.12.attn_norm.weight = util.global.load @__auto.blk.12.attn_norm.weight : tensor<4096xf32> - %133 = torch_c.from_builtin_tensor %__auto.blk.12.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %170 = torch_c.from_builtin_tensor %__auto.blk.12.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.12.attn_q.weight = util.global.load @__auto.blk.12.attn_q.weight : tensor<4096x4096xf16> - %134 = torch_c.from_builtin_tensor %__auto.blk.12.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %171 = torch_c.from_builtin_tensor %__auto.blk.12.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.12.attn_k.weight = util.global.load @__auto.blk.12.attn_k.weight
: tensor<1024x4096xf16> - %135 = torch_c.from_builtin_tensor %__auto.blk.12.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %172 = torch_c.from_builtin_tensor %__auto.blk.12.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.12.attn_v.weight = util.global.load @__auto.blk.12.attn_v.weight : tensor<1024x4096xf16> - %136 = torch_c.from_builtin_tensor %__auto.blk.12.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %137 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> - %138 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %173 = torch_c.from_builtin_tensor %__auto.blk.12.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %174 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %175 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> + %176 = torch.vtensor.literal(dense<12> : tensor<si64>) : !torch.vtensor<[],si64> + %177 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64> + %178 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64> %__auto.blk.12.attn_output.weight = util.global.load @__auto.blk.12.attn_output.weight : tensor<4096x4096xf16> - %139 = torch_c.from_builtin_tensor %__auto.blk.12.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %179 = torch_c.from_builtin_tensor %__auto.blk.12.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.12.ffn_norm.weight = util.global.load @__auto.blk.12.ffn_norm.weight : tensor<4096xf32> - %140 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %180 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.12.ffn_gate.weight = util.global.load @__auto.blk.12.ffn_gate.weight : tensor<14336x4096xf16> - %141 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %181 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.12.ffn_up.weight = util.global.load @__auto.blk.12.ffn_up.weight : tensor<14336x4096xf16> - %142 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %182 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.12.ffn_down.weight = util.global.load @__auto.blk.12.ffn_down.weight : tensor<4096x14336xf16> - %143 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %183 = torch_c.from_builtin_tensor %__auto.blk.12.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.13.attn_norm.weight = util.global.load @__auto.blk.13.attn_norm.weight : tensor<4096xf32> - %144 = torch_c.from_builtin_tensor %__auto.blk.13.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %184 = torch_c.from_builtin_tensor %__auto.blk.13.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.13.attn_q.weight = util.global.load @__auto.blk.13.attn_q.weight : tensor<4096x4096xf16> - %145 = torch_c.from_builtin_tensor %__auto.blk.13.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %185 = torch_c.from_builtin_tensor
%__auto.blk.13.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.13.attn_k.weight = util.global.load @__auto.blk.13.attn_k.weight : tensor<1024x4096xf16> - %146 = torch_c.from_builtin_tensor %__auto.blk.13.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %186 = torch_c.from_builtin_tensor %__auto.blk.13.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.13.attn_v.weight = util.global.load @__auto.blk.13.attn_v.weight : tensor<1024x4096xf16> - %147 = torch_c.from_builtin_tensor %__auto.blk.13.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %148 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %149 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %187 = torch_c.from_builtin_tensor %__auto.blk.13.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %188 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %189 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %190 = torch.vtensor.literal(dense<13> : tensor) : !torch.vtensor<[],si64> + %191 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %192 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.13.attn_output.weight = util.global.load @__auto.blk.13.attn_output.weight : tensor<4096x4096xf16> - %150 = torch_c.from_builtin_tensor %__auto.blk.13.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %193 = torch_c.from_builtin_tensor %__auto.blk.13.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.13.ffn_norm.weight = util.global.load @__auto.blk.13.ffn_norm.weight : tensor<4096xf32> - %151 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %194 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.13.ffn_gate.weight = util.global.load @__auto.blk.13.ffn_gate.weight : tensor<14336x4096xf16> - %152 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %195 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.13.ffn_up.weight = util.global.load @__auto.blk.13.ffn_up.weight : tensor<14336x4096xf16> - %153 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %196 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.13.ffn_down.weight = util.global.load @__auto.blk.13.ffn_down.weight : tensor<4096x14336xf16> - %154 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %197 = torch_c.from_builtin_tensor %__auto.blk.13.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.14.attn_norm.weight = util.global.load @__auto.blk.14.attn_norm.weight : tensor<4096xf32> - %155 = torch_c.from_builtin_tensor %__auto.blk.14.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %198 = torch_c.from_builtin_tensor %__auto.blk.14.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.14.attn_q.weight = util.global.load @__auto.blk.14.attn_q.weight : tensor<4096x4096xf16> - 
%156 = torch_c.from_builtin_tensor %__auto.blk.14.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %199 = torch_c.from_builtin_tensor %__auto.blk.14.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.14.attn_k.weight = util.global.load @__auto.blk.14.attn_k.weight : tensor<1024x4096xf16> - %157 = torch_c.from_builtin_tensor %__auto.blk.14.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %200 = torch_c.from_builtin_tensor %__auto.blk.14.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.14.attn_v.weight = util.global.load @__auto.blk.14.attn_v.weight : tensor<1024x4096xf16> - %158 = torch_c.from_builtin_tensor %__auto.blk.14.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %159 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %160 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %201 = torch_c.from_builtin_tensor %__auto.blk.14.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %202 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %203 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %204 = torch.vtensor.literal(dense<14> : tensor) : !torch.vtensor<[],si64> + %205 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %206 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.14.attn_output.weight = util.global.load @__auto.blk.14.attn_output.weight : tensor<4096x4096xf16> - %161 = torch_c.from_builtin_tensor %__auto.blk.14.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %207 = torch_c.from_builtin_tensor %__auto.blk.14.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.14.ffn_norm.weight = util.global.load @__auto.blk.14.ffn_norm.weight : tensor<4096xf32> - %162 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %208 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.14.ffn_gate.weight = util.global.load @__auto.blk.14.ffn_gate.weight : tensor<14336x4096xf16> - %163 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %209 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.14.ffn_up.weight = util.global.load @__auto.blk.14.ffn_up.weight : tensor<14336x4096xf16> - %164 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %210 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.14.ffn_down.weight = util.global.load @__auto.blk.14.ffn_down.weight : tensor<4096x14336xf16> - %165 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %211 = torch_c.from_builtin_tensor %__auto.blk.14.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.15.attn_norm.weight = util.global.load @__auto.blk.15.attn_norm.weight : tensor<4096xf32> - %166 = torch_c.from_builtin_tensor %__auto.blk.15.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %212 = torch_c.from_builtin_tensor 
%__auto.blk.15.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.15.attn_q.weight = util.global.load @__auto.blk.15.attn_q.weight : tensor<4096x4096xf16> - %167 = torch_c.from_builtin_tensor %__auto.blk.15.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %213 = torch_c.from_builtin_tensor %__auto.blk.15.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.15.attn_k.weight = util.global.load @__auto.blk.15.attn_k.weight : tensor<1024x4096xf16> - %168 = torch_c.from_builtin_tensor %__auto.blk.15.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %214 = torch_c.from_builtin_tensor %__auto.blk.15.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.15.attn_v.weight = util.global.load @__auto.blk.15.attn_v.weight : tensor<1024x4096xf16> - %169 = torch_c.from_builtin_tensor %__auto.blk.15.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %170 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %171 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %215 = torch_c.from_builtin_tensor %__auto.blk.15.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %216 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %217 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %218 = torch.vtensor.literal(dense<15> : tensor) : !torch.vtensor<[],si64> + %219 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %220 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.15.attn_output.weight = util.global.load @__auto.blk.15.attn_output.weight : tensor<4096x4096xf16> - %172 = torch_c.from_builtin_tensor %__auto.blk.15.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %221 = torch_c.from_builtin_tensor %__auto.blk.15.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.15.ffn_norm.weight = util.global.load @__auto.blk.15.ffn_norm.weight : tensor<4096xf32> - %173 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %222 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.15.ffn_gate.weight = util.global.load @__auto.blk.15.ffn_gate.weight : tensor<14336x4096xf16> - %174 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %223 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.15.ffn_up.weight = util.global.load @__auto.blk.15.ffn_up.weight : tensor<14336x4096xf16> - %175 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %224 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.15.ffn_down.weight = util.global.load @__auto.blk.15.ffn_down.weight : tensor<4096x14336xf16> - %176 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %225 = torch_c.from_builtin_tensor %__auto.blk.15.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.16.attn_norm.weight = util.global.load @__auto.blk.16.attn_norm.weight : 
tensor<4096xf32> - %177 = torch_c.from_builtin_tensor %__auto.blk.16.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %226 = torch_c.from_builtin_tensor %__auto.blk.16.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.16.attn_q.weight = util.global.load @__auto.blk.16.attn_q.weight : tensor<4096x4096xf16> - %178 = torch_c.from_builtin_tensor %__auto.blk.16.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %227 = torch_c.from_builtin_tensor %__auto.blk.16.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.16.attn_k.weight = util.global.load @__auto.blk.16.attn_k.weight : tensor<1024x4096xf16> - %179 = torch_c.from_builtin_tensor %__auto.blk.16.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %228 = torch_c.from_builtin_tensor %__auto.blk.16.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.16.attn_v.weight = util.global.load @__auto.blk.16.attn_v.weight : tensor<1024x4096xf16> - %180 = torch_c.from_builtin_tensor %__auto.blk.16.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %181 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %182 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %229 = torch_c.from_builtin_tensor %__auto.blk.16.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %230 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %231 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %232 = torch.vtensor.literal(dense<16> : tensor) : !torch.vtensor<[],si64> + %233 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %234 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.16.attn_output.weight = util.global.load @__auto.blk.16.attn_output.weight : tensor<4096x4096xf16> - %183 = torch_c.from_builtin_tensor %__auto.blk.16.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %235 = torch_c.from_builtin_tensor %__auto.blk.16.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.16.ffn_norm.weight = util.global.load @__auto.blk.16.ffn_norm.weight : tensor<4096xf32> - %184 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %236 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.16.ffn_gate.weight = util.global.load @__auto.blk.16.ffn_gate.weight : tensor<14336x4096xf16> - %185 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %237 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.16.ffn_up.weight = util.global.load @__auto.blk.16.ffn_up.weight : tensor<14336x4096xf16> - %186 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %238 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.16.ffn_down.weight = util.global.load @__auto.blk.16.ffn_down.weight : tensor<4096x14336xf16> - %187 = torch_c.from_builtin_tensor %__auto.blk.16.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %239 = torch_c.from_builtin_tensor 
%__auto.blk.16.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.17.attn_norm.weight = util.global.load @__auto.blk.17.attn_norm.weight : tensor<4096xf32> - %188 = torch_c.from_builtin_tensor %__auto.blk.17.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %240 = torch_c.from_builtin_tensor %__auto.blk.17.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.17.attn_q.weight = util.global.load @__auto.blk.17.attn_q.weight : tensor<4096x4096xf16> - %189 = torch_c.from_builtin_tensor %__auto.blk.17.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %241 = torch_c.from_builtin_tensor %__auto.blk.17.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.17.attn_k.weight = util.global.load @__auto.blk.17.attn_k.weight : tensor<1024x4096xf16> - %190 = torch_c.from_builtin_tensor %__auto.blk.17.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %242 = torch_c.from_builtin_tensor %__auto.blk.17.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.17.attn_v.weight = util.global.load @__auto.blk.17.attn_v.weight : tensor<1024x4096xf16> - %191 = torch_c.from_builtin_tensor %__auto.blk.17.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %192 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %193 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %243 = torch_c.from_builtin_tensor %__auto.blk.17.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %244 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %245 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %246 = torch.vtensor.literal(dense<17> : tensor) : !torch.vtensor<[],si64> + %247 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %248 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.17.attn_output.weight = util.global.load @__auto.blk.17.attn_output.weight : tensor<4096x4096xf16> - %194 = torch_c.from_builtin_tensor %__auto.blk.17.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %249 = torch_c.from_builtin_tensor %__auto.blk.17.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.17.ffn_norm.weight = util.global.load @__auto.blk.17.ffn_norm.weight : tensor<4096xf32> - %195 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %250 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.17.ffn_gate.weight = util.global.load @__auto.blk.17.ffn_gate.weight : tensor<14336x4096xf16> - %196 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %251 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.17.ffn_up.weight = util.global.load @__auto.blk.17.ffn_up.weight : tensor<14336x4096xf16> - %197 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %252 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.17.ffn_down.weight = util.global.load @__auto.blk.17.ffn_down.weight : tensor<4096x14336xf16> - %198 
= torch_c.from_builtin_tensor %__auto.blk.17.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %253 = torch_c.from_builtin_tensor %__auto.blk.17.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.18.attn_norm.weight = util.global.load @__auto.blk.18.attn_norm.weight : tensor<4096xf32> - %199 = torch_c.from_builtin_tensor %__auto.blk.18.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %254 = torch_c.from_builtin_tensor %__auto.blk.18.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.18.attn_q.weight = util.global.load @__auto.blk.18.attn_q.weight : tensor<4096x4096xf16> - %200 = torch_c.from_builtin_tensor %__auto.blk.18.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %255 = torch_c.from_builtin_tensor %__auto.blk.18.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.18.attn_k.weight = util.global.load @__auto.blk.18.attn_k.weight : tensor<1024x4096xf16> - %201 = torch_c.from_builtin_tensor %__auto.blk.18.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %256 = torch_c.from_builtin_tensor %__auto.blk.18.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.18.attn_v.weight = util.global.load @__auto.blk.18.attn_v.weight : tensor<1024x4096xf16> - %202 = torch_c.from_builtin_tensor %__auto.blk.18.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %203 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %204 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %257 = torch_c.from_builtin_tensor %__auto.blk.18.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %258 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %259 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %260 = torch.vtensor.literal(dense<18> : tensor) : !torch.vtensor<[],si64> + %261 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %262 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.18.attn_output.weight = util.global.load @__auto.blk.18.attn_output.weight : tensor<4096x4096xf16> - %205 = torch_c.from_builtin_tensor %__auto.blk.18.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %263 = torch_c.from_builtin_tensor %__auto.blk.18.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.18.ffn_norm.weight = util.global.load @__auto.blk.18.ffn_norm.weight : tensor<4096xf32> - %206 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %264 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.18.ffn_gate.weight = util.global.load @__auto.blk.18.ffn_gate.weight : tensor<14336x4096xf16> - %207 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %265 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.18.ffn_up.weight = util.global.load @__auto.blk.18.ffn_up.weight : tensor<14336x4096xf16> - %208 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %266 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_up.weight : 
tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.18.ffn_down.weight = util.global.load @__auto.blk.18.ffn_down.weight : tensor<4096x14336xf16> - %209 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %267 = torch_c.from_builtin_tensor %__auto.blk.18.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.19.attn_norm.weight = util.global.load @__auto.blk.19.attn_norm.weight : tensor<4096xf32> - %210 = torch_c.from_builtin_tensor %__auto.blk.19.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %268 = torch_c.from_builtin_tensor %__auto.blk.19.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.19.attn_q.weight = util.global.load @__auto.blk.19.attn_q.weight : tensor<4096x4096xf16> - %211 = torch_c.from_builtin_tensor %__auto.blk.19.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %269 = torch_c.from_builtin_tensor %__auto.blk.19.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.19.attn_k.weight = util.global.load @__auto.blk.19.attn_k.weight : tensor<1024x4096xf16> - %212 = torch_c.from_builtin_tensor %__auto.blk.19.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %270 = torch_c.from_builtin_tensor %__auto.blk.19.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.19.attn_v.weight = util.global.load @__auto.blk.19.attn_v.weight : tensor<1024x4096xf16> - %213 = torch_c.from_builtin_tensor %__auto.blk.19.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %214 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %215 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %271 = torch_c.from_builtin_tensor %__auto.blk.19.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %272 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %273 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %274 = torch.vtensor.literal(dense<19> : tensor) : !torch.vtensor<[],si64> + %275 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %276 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.19.attn_output.weight = util.global.load @__auto.blk.19.attn_output.weight : tensor<4096x4096xf16> - %216 = torch_c.from_builtin_tensor %__auto.blk.19.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %277 = torch_c.from_builtin_tensor %__auto.blk.19.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.19.ffn_norm.weight = util.global.load @__auto.blk.19.ffn_norm.weight : tensor<4096xf32> - %217 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %278 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.19.ffn_gate.weight = util.global.load @__auto.blk.19.ffn_gate.weight : tensor<14336x4096xf16> - %218 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %279 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.19.ffn_up.weight = util.global.load @__auto.blk.19.ffn_up.weight : tensor<14336x4096xf16> - %219 = torch_c.from_builtin_tensor 
%__auto.blk.19.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %280 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.19.ffn_down.weight = util.global.load @__auto.blk.19.ffn_down.weight : tensor<4096x14336xf16> - %220 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %281 = torch_c.from_builtin_tensor %__auto.blk.19.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.20.attn_norm.weight = util.global.load @__auto.blk.20.attn_norm.weight : tensor<4096xf32> - %221 = torch_c.from_builtin_tensor %__auto.blk.20.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %282 = torch_c.from_builtin_tensor %__auto.blk.20.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.20.attn_q.weight = util.global.load @__auto.blk.20.attn_q.weight : tensor<4096x4096xf16> - %222 = torch_c.from_builtin_tensor %__auto.blk.20.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %283 = torch_c.from_builtin_tensor %__auto.blk.20.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.20.attn_k.weight = util.global.load @__auto.blk.20.attn_k.weight : tensor<1024x4096xf16> - %223 = torch_c.from_builtin_tensor %__auto.blk.20.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %284 = torch_c.from_builtin_tensor %__auto.blk.20.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.20.attn_v.weight = util.global.load @__auto.blk.20.attn_v.weight : tensor<1024x4096xf16> - %224 = torch_c.from_builtin_tensor %__auto.blk.20.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %225 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %226 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %285 = torch_c.from_builtin_tensor %__auto.blk.20.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %286 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %287 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %288 = torch.vtensor.literal(dense<20> : tensor) : !torch.vtensor<[],si64> + %289 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %290 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.20.attn_output.weight = util.global.load @__auto.blk.20.attn_output.weight : tensor<4096x4096xf16> - %227 = torch_c.from_builtin_tensor %__auto.blk.20.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %291 = torch_c.from_builtin_tensor %__auto.blk.20.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.20.ffn_norm.weight = util.global.load @__auto.blk.20.ffn_norm.weight : tensor<4096xf32> - %228 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %292 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.20.ffn_gate.weight = util.global.load @__auto.blk.20.ffn_gate.weight : tensor<14336x4096xf16> - %229 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %293 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_gate.weight : tensor<14336x4096xf16> -> 
!torch.vtensor<[14336,4096],f16> %__auto.blk.20.ffn_up.weight = util.global.load @__auto.blk.20.ffn_up.weight : tensor<14336x4096xf16> - %230 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %294 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.20.ffn_down.weight = util.global.load @__auto.blk.20.ffn_down.weight : tensor<4096x14336xf16> - %231 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %295 = torch_c.from_builtin_tensor %__auto.blk.20.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.21.attn_norm.weight = util.global.load @__auto.blk.21.attn_norm.weight : tensor<4096xf32> - %232 = torch_c.from_builtin_tensor %__auto.blk.21.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %296 = torch_c.from_builtin_tensor %__auto.blk.21.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.21.attn_q.weight = util.global.load @__auto.blk.21.attn_q.weight : tensor<4096x4096xf16> - %233 = torch_c.from_builtin_tensor %__auto.blk.21.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %297 = torch_c.from_builtin_tensor %__auto.blk.21.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.21.attn_k.weight = util.global.load @__auto.blk.21.attn_k.weight : tensor<1024x4096xf16> - %234 = torch_c.from_builtin_tensor %__auto.blk.21.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %298 = torch_c.from_builtin_tensor %__auto.blk.21.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.21.attn_v.weight = util.global.load @__auto.blk.21.attn_v.weight : tensor<1024x4096xf16> - %235 = torch_c.from_builtin_tensor %__auto.blk.21.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %236 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %237 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %299 = torch_c.from_builtin_tensor %__auto.blk.21.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %300 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %301 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %302 = torch.vtensor.literal(dense<21> : tensor) : !torch.vtensor<[],si64> + %303 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %304 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.21.attn_output.weight = util.global.load @__auto.blk.21.attn_output.weight : tensor<4096x4096xf16> - %238 = torch_c.from_builtin_tensor %__auto.blk.21.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %305 = torch_c.from_builtin_tensor %__auto.blk.21.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.21.ffn_norm.weight = util.global.load @__auto.blk.21.ffn_norm.weight : tensor<4096xf32> - %239 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %306 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.21.ffn_gate.weight = util.global.load @__auto.blk.21.ffn_gate.weight : tensor<14336x4096xf16> - %240 = torch_c.from_builtin_tensor 
%__auto.blk.21.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %307 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.21.ffn_up.weight = util.global.load @__auto.blk.21.ffn_up.weight : tensor<14336x4096xf16> - %241 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %308 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.21.ffn_down.weight = util.global.load @__auto.blk.21.ffn_down.weight : tensor<4096x14336xf16> - %242 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %309 = torch_c.from_builtin_tensor %__auto.blk.21.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.22.attn_norm.weight = util.global.load @__auto.blk.22.attn_norm.weight : tensor<4096xf32> - %243 = torch_c.from_builtin_tensor %__auto.blk.22.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %310 = torch_c.from_builtin_tensor %__auto.blk.22.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.22.attn_q.weight = util.global.load @__auto.blk.22.attn_q.weight : tensor<4096x4096xf16> - %244 = torch_c.from_builtin_tensor %__auto.blk.22.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %311 = torch_c.from_builtin_tensor %__auto.blk.22.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.22.attn_k.weight = util.global.load @__auto.blk.22.attn_k.weight : tensor<1024x4096xf16> - %245 = torch_c.from_builtin_tensor %__auto.blk.22.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %312 = torch_c.from_builtin_tensor %__auto.blk.22.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.22.attn_v.weight = util.global.load @__auto.blk.22.attn_v.weight : tensor<1024x4096xf16> - %246 = torch_c.from_builtin_tensor %__auto.blk.22.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %247 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %248 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %313 = torch_c.from_builtin_tensor %__auto.blk.22.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %314 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %315 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %316 = torch.vtensor.literal(dense<22> : tensor) : !torch.vtensor<[],si64> + %317 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %318 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.22.attn_output.weight = util.global.load @__auto.blk.22.attn_output.weight : tensor<4096x4096xf16> - %249 = torch_c.from_builtin_tensor %__auto.blk.22.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %319 = torch_c.from_builtin_tensor %__auto.blk.22.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.22.ffn_norm.weight = util.global.load @__auto.blk.22.ffn_norm.weight : tensor<4096xf32> - %250 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %320 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_norm.weight : tensor<4096xf32> -> 
!torch.vtensor<[4096],f32> %__auto.blk.22.ffn_gate.weight = util.global.load @__auto.blk.22.ffn_gate.weight : tensor<14336x4096xf16> - %251 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %321 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.22.ffn_up.weight = util.global.load @__auto.blk.22.ffn_up.weight : tensor<14336x4096xf16> - %252 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %322 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.22.ffn_down.weight = util.global.load @__auto.blk.22.ffn_down.weight : tensor<4096x14336xf16> - %253 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %323 = torch_c.from_builtin_tensor %__auto.blk.22.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.23.attn_norm.weight = util.global.load @__auto.blk.23.attn_norm.weight : tensor<4096xf32> - %254 = torch_c.from_builtin_tensor %__auto.blk.23.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %324 = torch_c.from_builtin_tensor %__auto.blk.23.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.23.attn_q.weight = util.global.load @__auto.blk.23.attn_q.weight : tensor<4096x4096xf16> - %255 = torch_c.from_builtin_tensor %__auto.blk.23.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %325 = torch_c.from_builtin_tensor %__auto.blk.23.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.23.attn_k.weight = util.global.load @__auto.blk.23.attn_k.weight : tensor<1024x4096xf16> - %256 = torch_c.from_builtin_tensor %__auto.blk.23.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %326 = torch_c.from_builtin_tensor %__auto.blk.23.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.23.attn_v.weight = util.global.load @__auto.blk.23.attn_v.weight : tensor<1024x4096xf16> - %257 = torch_c.from_builtin_tensor %__auto.blk.23.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %258 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %259 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %327 = torch_c.from_builtin_tensor %__auto.blk.23.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %328 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %329 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %330 = torch.vtensor.literal(dense<23> : tensor) : !torch.vtensor<[],si64> + %331 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %332 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.23.attn_output.weight = util.global.load @__auto.blk.23.attn_output.weight : tensor<4096x4096xf16> - %260 = torch_c.from_builtin_tensor %__auto.blk.23.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %333 = torch_c.from_builtin_tensor %__auto.blk.23.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.23.ffn_norm.weight = util.global.load @__auto.blk.23.ffn_norm.weight : tensor<4096xf32> - %261 = torch_c.from_builtin_tensor 
%__auto.blk.23.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %334 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.23.ffn_gate.weight = util.global.load @__auto.blk.23.ffn_gate.weight : tensor<14336x4096xf16> - %262 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %335 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.23.ffn_up.weight = util.global.load @__auto.blk.23.ffn_up.weight : tensor<14336x4096xf16> - %263 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %336 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.23.ffn_down.weight = util.global.load @__auto.blk.23.ffn_down.weight : tensor<4096x14336xf16> - %264 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %337 = torch_c.from_builtin_tensor %__auto.blk.23.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.24.attn_norm.weight = util.global.load @__auto.blk.24.attn_norm.weight : tensor<4096xf32> - %265 = torch_c.from_builtin_tensor %__auto.blk.24.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %338 = torch_c.from_builtin_tensor %__auto.blk.24.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.24.attn_q.weight = util.global.load @__auto.blk.24.attn_q.weight : tensor<4096x4096xf16> - %266 = torch_c.from_builtin_tensor %__auto.blk.24.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %339 = torch_c.from_builtin_tensor %__auto.blk.24.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.24.attn_k.weight = util.global.load @__auto.blk.24.attn_k.weight : tensor<1024x4096xf16> - %267 = torch_c.from_builtin_tensor %__auto.blk.24.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %340 = torch_c.from_builtin_tensor %__auto.blk.24.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.24.attn_v.weight = util.global.load @__auto.blk.24.attn_v.weight : tensor<1024x4096xf16> - %268 = torch_c.from_builtin_tensor %__auto.blk.24.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %269 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %270 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %341 = torch_c.from_builtin_tensor %__auto.blk.24.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %342 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %343 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %344 = torch.vtensor.literal(dense<24> : tensor) : !torch.vtensor<[],si64> + %345 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %346 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.24.attn_output.weight = util.global.load @__auto.blk.24.attn_output.weight : tensor<4096x4096xf16> - %271 = torch_c.from_builtin_tensor %__auto.blk.24.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %347 = torch_c.from_builtin_tensor %__auto.blk.24.attn_output.weight : tensor<4096x4096xf16> -> 
!torch.vtensor<[4096,4096],f16> %__auto.blk.24.ffn_norm.weight = util.global.load @__auto.blk.24.ffn_norm.weight : tensor<4096xf32> - %272 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %348 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.24.ffn_gate.weight = util.global.load @__auto.blk.24.ffn_gate.weight : tensor<14336x4096xf16> - %273 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %349 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.24.ffn_up.weight = util.global.load @__auto.blk.24.ffn_up.weight : tensor<14336x4096xf16> - %274 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %350 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.24.ffn_down.weight = util.global.load @__auto.blk.24.ffn_down.weight : tensor<4096x14336xf16> - %275 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %351 = torch_c.from_builtin_tensor %__auto.blk.24.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.25.attn_norm.weight = util.global.load @__auto.blk.25.attn_norm.weight : tensor<4096xf32> - %276 = torch_c.from_builtin_tensor %__auto.blk.25.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %352 = torch_c.from_builtin_tensor %__auto.blk.25.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.25.attn_q.weight = util.global.load @__auto.blk.25.attn_q.weight : tensor<4096x4096xf16> - %277 = torch_c.from_builtin_tensor %__auto.blk.25.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %353 = torch_c.from_builtin_tensor %__auto.blk.25.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.25.attn_k.weight = util.global.load @__auto.blk.25.attn_k.weight : tensor<1024x4096xf16> - %278 = torch_c.from_builtin_tensor %__auto.blk.25.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %354 = torch_c.from_builtin_tensor %__auto.blk.25.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.25.attn_v.weight = util.global.load @__auto.blk.25.attn_v.weight : tensor<1024x4096xf16> - %279 = torch_c.from_builtin_tensor %__auto.blk.25.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %280 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %281 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %355 = torch_c.from_builtin_tensor %__auto.blk.25.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %356 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %357 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %358 = torch.vtensor.literal(dense<25> : tensor) : !torch.vtensor<[],si64> + %359 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %360 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.25.attn_output.weight = util.global.load @__auto.blk.25.attn_output.weight : tensor<4096x4096xf16> - %282 = torch_c.from_builtin_tensor 
%__auto.blk.25.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %361 = torch_c.from_builtin_tensor %__auto.blk.25.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.25.ffn_norm.weight = util.global.load @__auto.blk.25.ffn_norm.weight : tensor<4096xf32> - %283 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %362 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.25.ffn_gate.weight = util.global.load @__auto.blk.25.ffn_gate.weight : tensor<14336x4096xf16> - %284 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %363 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.25.ffn_up.weight = util.global.load @__auto.blk.25.ffn_up.weight : tensor<14336x4096xf16> - %285 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %364 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.25.ffn_down.weight = util.global.load @__auto.blk.25.ffn_down.weight : tensor<4096x14336xf16> - %286 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %365 = torch_c.from_builtin_tensor %__auto.blk.25.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.26.attn_norm.weight = util.global.load @__auto.blk.26.attn_norm.weight : tensor<4096xf32> - %287 = torch_c.from_builtin_tensor %__auto.blk.26.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %366 = torch_c.from_builtin_tensor %__auto.blk.26.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.26.attn_q.weight = util.global.load @__auto.blk.26.attn_q.weight : tensor<4096x4096xf16> - %288 = torch_c.from_builtin_tensor %__auto.blk.26.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %367 = torch_c.from_builtin_tensor %__auto.blk.26.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.26.attn_k.weight = util.global.load @__auto.blk.26.attn_k.weight : tensor<1024x4096xf16> - %289 = torch_c.from_builtin_tensor %__auto.blk.26.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %368 = torch_c.from_builtin_tensor %__auto.blk.26.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.26.attn_v.weight = util.global.load @__auto.blk.26.attn_v.weight : tensor<1024x4096xf16> - %290 = torch_c.from_builtin_tensor %__auto.blk.26.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %291 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %292 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %369 = torch_c.from_builtin_tensor %__auto.blk.26.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %370 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %371 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %372 = torch.vtensor.literal(dense<26> : tensor) : !torch.vtensor<[],si64> + %373 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %374 = torch.vtensor.literal(dense<1> : tensor) : 
!torch.vtensor<[],si64> %__auto.blk.26.attn_output.weight = util.global.load @__auto.blk.26.attn_output.weight : tensor<4096x4096xf16> - %293 = torch_c.from_builtin_tensor %__auto.blk.26.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %375 = torch_c.from_builtin_tensor %__auto.blk.26.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.26.ffn_norm.weight = util.global.load @__auto.blk.26.ffn_norm.weight : tensor<4096xf32> - %294 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %376 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.26.ffn_gate.weight = util.global.load @__auto.blk.26.ffn_gate.weight : tensor<14336x4096xf16> - %295 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %377 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.26.ffn_up.weight = util.global.load @__auto.blk.26.ffn_up.weight : tensor<14336x4096xf16> - %296 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %378 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.26.ffn_down.weight = util.global.load @__auto.blk.26.ffn_down.weight : tensor<4096x14336xf16> - %297 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %379 = torch_c.from_builtin_tensor %__auto.blk.26.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.27.attn_norm.weight = util.global.load @__auto.blk.27.attn_norm.weight : tensor<4096xf32> - %298 = torch_c.from_builtin_tensor %__auto.blk.27.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %380 = torch_c.from_builtin_tensor %__auto.blk.27.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.27.attn_q.weight = util.global.load @__auto.blk.27.attn_q.weight : tensor<4096x4096xf16> - %299 = torch_c.from_builtin_tensor %__auto.blk.27.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %381 = torch_c.from_builtin_tensor %__auto.blk.27.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.27.attn_k.weight = util.global.load @__auto.blk.27.attn_k.weight : tensor<1024x4096xf16> - %300 = torch_c.from_builtin_tensor %__auto.blk.27.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %382 = torch_c.from_builtin_tensor %__auto.blk.27.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.27.attn_v.weight = util.global.load @__auto.blk.27.attn_v.weight : tensor<1024x4096xf16> - %301 = torch_c.from_builtin_tensor %__auto.blk.27.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %302 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %303 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %383 = torch_c.from_builtin_tensor %__auto.blk.27.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %384 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %385 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %386 = torch.vtensor.literal(dense<27> : 
tensor) : !torch.vtensor<[],si64> + %387 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %388 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.27.attn_output.weight = util.global.load @__auto.blk.27.attn_output.weight : tensor<4096x4096xf16> - %304 = torch_c.from_builtin_tensor %__auto.blk.27.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %389 = torch_c.from_builtin_tensor %__auto.blk.27.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.27.ffn_norm.weight = util.global.load @__auto.blk.27.ffn_norm.weight : tensor<4096xf32> - %305 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %390 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.27.ffn_gate.weight = util.global.load @__auto.blk.27.ffn_gate.weight : tensor<14336x4096xf16> - %306 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %391 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.27.ffn_up.weight = util.global.load @__auto.blk.27.ffn_up.weight : tensor<14336x4096xf16> - %307 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %392 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.27.ffn_down.weight = util.global.load @__auto.blk.27.ffn_down.weight : tensor<4096x14336xf16> - %308 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %393 = torch_c.from_builtin_tensor %__auto.blk.27.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.28.attn_norm.weight = util.global.load @__auto.blk.28.attn_norm.weight : tensor<4096xf32> - %309 = torch_c.from_builtin_tensor %__auto.blk.28.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %394 = torch_c.from_builtin_tensor %__auto.blk.28.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.28.attn_q.weight = util.global.load @__auto.blk.28.attn_q.weight : tensor<4096x4096xf16> - %310 = torch_c.from_builtin_tensor %__auto.blk.28.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %395 = torch_c.from_builtin_tensor %__auto.blk.28.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.28.attn_k.weight = util.global.load @__auto.blk.28.attn_k.weight : tensor<1024x4096xf16> - %311 = torch_c.from_builtin_tensor %__auto.blk.28.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %396 = torch_c.from_builtin_tensor %__auto.blk.28.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.28.attn_v.weight = util.global.load @__auto.blk.28.attn_v.weight : tensor<1024x4096xf16> - %312 = torch_c.from_builtin_tensor %__auto.blk.28.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %313 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %314 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %397 = torch_c.from_builtin_tensor %__auto.blk.28.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %398 = 
torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %399 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %400 = torch.vtensor.literal(dense<28> : tensor) : !torch.vtensor<[],si64> + %401 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %402 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.28.attn_output.weight = util.global.load @__auto.blk.28.attn_output.weight : tensor<4096x4096xf16> - %315 = torch_c.from_builtin_tensor %__auto.blk.28.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %403 = torch_c.from_builtin_tensor %__auto.blk.28.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.28.ffn_norm.weight = util.global.load @__auto.blk.28.ffn_norm.weight : tensor<4096xf32> - %316 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %404 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.28.ffn_gate.weight = util.global.load @__auto.blk.28.ffn_gate.weight : tensor<14336x4096xf16> - %317 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %405 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.28.ffn_up.weight = util.global.load @__auto.blk.28.ffn_up.weight : tensor<14336x4096xf16> - %318 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %406 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.28.ffn_down.weight = util.global.load @__auto.blk.28.ffn_down.weight : tensor<4096x14336xf16> - %319 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %407 = torch_c.from_builtin_tensor %__auto.blk.28.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.29.attn_norm.weight = util.global.load @__auto.blk.29.attn_norm.weight : tensor<4096xf32> - %320 = torch_c.from_builtin_tensor %__auto.blk.29.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %408 = torch_c.from_builtin_tensor %__auto.blk.29.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.29.attn_q.weight = util.global.load @__auto.blk.29.attn_q.weight : tensor<4096x4096xf16> - %321 = torch_c.from_builtin_tensor %__auto.blk.29.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %409 = torch_c.from_builtin_tensor %__auto.blk.29.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.29.attn_k.weight = util.global.load @__auto.blk.29.attn_k.weight : tensor<1024x4096xf16> - %322 = torch_c.from_builtin_tensor %__auto.blk.29.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %410 = torch_c.from_builtin_tensor %__auto.blk.29.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.29.attn_v.weight = util.global.load @__auto.blk.29.attn_v.weight : tensor<1024x4096xf16> - %323 = torch_c.from_builtin_tensor %__auto.blk.29.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %324 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %325 = torch.vtensor.literal(dense<1> : tensor) : 
!torch.vtensor<[],si64> + %411 = torch_c.from_builtin_tensor %__auto.blk.29.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %412 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %413 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %414 = torch.vtensor.literal(dense<29> : tensor) : !torch.vtensor<[],si64> + %415 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %416 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.29.attn_output.weight = util.global.load @__auto.blk.29.attn_output.weight : tensor<4096x4096xf16> - %326 = torch_c.from_builtin_tensor %__auto.blk.29.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %417 = torch_c.from_builtin_tensor %__auto.blk.29.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.29.ffn_norm.weight = util.global.load @__auto.blk.29.ffn_norm.weight : tensor<4096xf32> - %327 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %418 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.29.ffn_gate.weight = util.global.load @__auto.blk.29.ffn_gate.weight : tensor<14336x4096xf16> - %328 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %419 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.29.ffn_up.weight = util.global.load @__auto.blk.29.ffn_up.weight : tensor<14336x4096xf16> - %329 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %420 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.29.ffn_down.weight = util.global.load @__auto.blk.29.ffn_down.weight : tensor<4096x14336xf16> - %330 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %421 = torch_c.from_builtin_tensor %__auto.blk.29.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.30.attn_norm.weight = util.global.load @__auto.blk.30.attn_norm.weight : tensor<4096xf32> - %331 = torch_c.from_builtin_tensor %__auto.blk.30.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %422 = torch_c.from_builtin_tensor %__auto.blk.30.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.30.attn_q.weight = util.global.load @__auto.blk.30.attn_q.weight : tensor<4096x4096xf16> - %332 = torch_c.from_builtin_tensor %__auto.blk.30.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %423 = torch_c.from_builtin_tensor %__auto.blk.30.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.30.attn_k.weight = util.global.load @__auto.blk.30.attn_k.weight : tensor<1024x4096xf16> - %333 = torch_c.from_builtin_tensor %__auto.blk.30.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %424 = torch_c.from_builtin_tensor %__auto.blk.30.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.30.attn_v.weight = util.global.load @__auto.blk.30.attn_v.weight : tensor<1024x4096xf16> - %334 = torch_c.from_builtin_tensor %__auto.blk.30.attn_v.weight : tensor<1024x4096xf16> -> 
!torch.vtensor<[1024,4096],f16> - %335 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %336 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %425 = torch_c.from_builtin_tensor %__auto.blk.30.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %426 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %427 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %428 = torch.vtensor.literal(dense<30> : tensor) : !torch.vtensor<[],si64> + %429 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %430 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.30.attn_output.weight = util.global.load @__auto.blk.30.attn_output.weight : tensor<4096x4096xf16> - %337 = torch_c.from_builtin_tensor %__auto.blk.30.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %431 = torch_c.from_builtin_tensor %__auto.blk.30.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.30.ffn_norm.weight = util.global.load @__auto.blk.30.ffn_norm.weight : tensor<4096xf32> - %338 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %432 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.30.ffn_gate.weight = util.global.load @__auto.blk.30.ffn_gate.weight : tensor<14336x4096xf16> - %339 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %433 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.30.ffn_up.weight = util.global.load @__auto.blk.30.ffn_up.weight : tensor<14336x4096xf16> - %340 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %434 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.30.ffn_down.weight = util.global.load @__auto.blk.30.ffn_down.weight : tensor<4096x14336xf16> - %341 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %435 = torch_c.from_builtin_tensor %__auto.blk.30.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.blk.31.attn_norm.weight = util.global.load @__auto.blk.31.attn_norm.weight : tensor<4096xf32> - %342 = torch_c.from_builtin_tensor %__auto.blk.31.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %436 = torch_c.from_builtin_tensor %__auto.blk.31.attn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.31.attn_q.weight = util.global.load @__auto.blk.31.attn_q.weight : tensor<4096x4096xf16> - %343 = torch_c.from_builtin_tensor %__auto.blk.31.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %437 = torch_c.from_builtin_tensor %__auto.blk.31.attn_q.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.31.attn_k.weight = util.global.load @__auto.blk.31.attn_k.weight : tensor<1024x4096xf16> - %344 = torch_c.from_builtin_tensor %__auto.blk.31.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %438 = torch_c.from_builtin_tensor %__auto.blk.31.attn_k.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> %__auto.blk.31.attn_v.weight = 
util.global.load @__auto.blk.31.attn_v.weight : tensor<1024x4096xf16> - %345 = torch_c.from_builtin_tensor %__auto.blk.31.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> - %346 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %347 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %439 = torch_c.from_builtin_tensor %__auto.blk.31.attn_v.weight : tensor<1024x4096xf16> -> !torch.vtensor<[1024,4096],f16> + %440 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %441 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %442 = torch.vtensor.literal(dense<31> : tensor) : !torch.vtensor<[],si64> + %443 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %444 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %__auto.blk.31.attn_output.weight = util.global.load @__auto.blk.31.attn_output.weight : tensor<4096x4096xf16> - %348 = torch_c.from_builtin_tensor %__auto.blk.31.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> + %445 = torch_c.from_builtin_tensor %__auto.blk.31.attn_output.weight : tensor<4096x4096xf16> -> !torch.vtensor<[4096,4096],f16> %__auto.blk.31.ffn_norm.weight = util.global.load @__auto.blk.31.ffn_norm.weight : tensor<4096xf32> - %349 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %446 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.blk.31.ffn_gate.weight = util.global.load @__auto.blk.31.ffn_gate.weight : tensor<14336x4096xf16> - %350 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %447 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_gate.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.31.ffn_up.weight = util.global.load @__auto.blk.31.ffn_up.weight : tensor<14336x4096xf16> - %351 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> + %448 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_up.weight : tensor<14336x4096xf16> -> !torch.vtensor<[14336,4096],f16> %__auto.blk.31.ffn_down.weight = util.global.load @__auto.blk.31.ffn_down.weight : tensor<4096x14336xf16> - %352 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> + %449 = torch_c.from_builtin_tensor %__auto.blk.31.ffn_down.weight : tensor<4096x14336xf16> -> !torch.vtensor<[4096,14336],f16> %__auto.output_norm.weight = util.global.load @__auto.output_norm.weight : tensor<4096xf32> - %353 = torch_c.from_builtin_tensor %__auto.output_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> + %450 = torch_c.from_builtin_tensor %__auto.output_norm.weight : tensor<4096xf32> -> !torch.vtensor<[4096],f32> %__auto.output.weight = util.global.load @__auto.output.weight : tensor<128256x4096xf16> - %354 = torch_c.from_builtin_tensor %__auto.output.weight : tensor<128256x4096xf16> -> !torch.vtensor<[128256,4096],f16> - %355 = torch.copy.to_vtensor %arg4 : !torch.vtensor<[?,2097152],f16> - %356 = torch.symbolic_int "s0" {min_val = 2, max_val = 4095} : !torch.int - %357 = torch.symbolic_int "s1" {min_val = 2, max_val = 9223372036854775806} : !torch.int - torch.bind_symbolic_shape %arg3, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %355, [%357], 
affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %451 = torch_c.from_builtin_tensor %__auto.output.weight : tensor<128256x4096xf16> -> !torch.vtensor<[128256,4096],f16> + %452 = torch.copy.to_vtensor %arg4 : !torch.vtensor<[?,2097152],f16> + %453 = torch.symbolic_int "s0" {min_val = 2, max_val = 4095} : !torch.int + %454 = torch.symbolic_int "s1" {min_val = 0, max_val = 9223372036854775807} : !torch.int + torch.bind_symbolic_shape %arg3, [%453], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + torch.bind_symbolic_shape %452, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> %int1 = torch.constant.int 1 - %358 = torch.aten.size.int %arg3, %int1 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.int - %int32 = torch.constant.int 32 - %359 = torch.aten.mul.int %358, %int32 : !torch.int, !torch.int -> !torch.int + %455 = torch.aten.size.int %arg3, %int1 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.int %int0 = torch.constant.int 0 - %int1_0 = torch.constant.int 1 + %456 = torch.aten.size.int %452, %int0 : !torch.vtensor<[?,2097152],f16>, !torch.int -> !torch.int + %int32 = torch.constant.int 32 + %457 = torch.aten.mul.int %455, %int32 : !torch.int, !torch.int -> !torch.int + %int0_0 = torch.constant.int 0 + %int1_1 = torch.constant.int 1 %none = torch.constant.none - %none_1 = torch.constant.none + %none_2 = torch.constant.none %cpu = torch.constant.device "cpu" %false = torch.constant.bool false - %360 = torch.aten.arange.start_step %int0, %359, %int1_0, %none, %none_1, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %360, [%356], affine_map<()[s0] -> (s0 * 32)> : !torch.vtensor<[?],si64> + %458 = torch.aten.arange.start_step %int0_0, %457, %int1_1, %none, %none_2, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[?],si64> + torch.bind_symbolic_shape %458, [%453], affine_map<()[s0] -> (s0 * 32)> : !torch.vtensor<[?],si64> %int-1 = torch.constant.int -1 - %361 = torch.aten.unsqueeze %arg1, %int-1 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %362 = torch.aten.ge.Tensor %360, %361 : !torch.vtensor<[?],si64>, !torch.vtensor<[4,1],si64> -> !torch.vtensor<[4,?],i1> - torch.bind_symbolic_shape %362, [%356], affine_map<()[s0] -> (4, s0 * 32)> : !torch.vtensor<[4,?],i1> - %int0_2 = torch.constant.int 0 - %int6 = torch.constant.int 6 - %int0_3 = torch.constant.int 0 - %cpu_4 = torch.constant.device "cpu" - %none_5 = torch.constant.none - %363 = torch.aten.scalar_tensor %int0_2, %int6, %int0_3, %cpu_4, %none_5 : !torch.int, !torch.int, !torch.int, !torch.Device, !torch.none -> !torch.vtensor<[],f32> - %float-Inf = torch.constant.float 0xFFF0000000000000 - %int6_6 = torch.constant.int 6 - %int0_7 = torch.constant.int 0 - %cpu_8 = torch.constant.device "cpu" - %none_9 = torch.constant.none - %364 = torch.aten.scalar_tensor %float-Inf, %int6_6, %int0_7, %cpu_8, %none_9 : !torch.float, !torch.int, !torch.int, !torch.Device, !torch.none -> !torch.vtensor<[],f32> - %365 = torch.aten.where.self %362, %364, %363 : !torch.vtensor<[4,?],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[4,?],f32> - torch.bind_symbolic_shape %365, [%356], affine_map<()[s0] -> (4, s0 * 32)> : !torch.vtensor<[4,?],f32> + %459 = torch.aten.unsqueeze %arg1, %int-1 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %460 = 
torch.aten.ge.Tensor %458, %459 : !torch.vtensor<[?],si64>, !torch.vtensor<[4,1],si64> -> !torch.vtensor<[4,?],i1> + torch.bind_symbolic_shape %460, [%453], affine_map<()[s0] -> (4, s0 * 32)> : !torch.vtensor<[4,?],i1> + %none_3 = torch.constant.none + %461 = torch.aten.clone %0, %none_3 : !torch.vtensor<[],f16>, !torch.none -> !torch.vtensor<[],f16> + %462 = torch.aten.detach %461 : !torch.vtensor<[],f16> -> !torch.vtensor<[],f16> + %463 = torch.aten.detach %462 : !torch.vtensor<[],f16> -> !torch.vtensor<[],f16> + %464 = torch.aten.detach %463 : !torch.vtensor<[],f16> -> !torch.vtensor<[],f16> + %int0_4 = torch.constant.int 0 %int5 = torch.constant.int 5 - %366 = torch.prims.convert_element_type %365, %int5 : !torch.vtensor<[4,?],f32>, !torch.int -> !torch.vtensor<[4,?],f16> - torch.bind_symbolic_shape %366, [%356], affine_map<()[s0] -> (4, s0 * 32)> : !torch.vtensor<[4,?],f16> + %int0_5 = torch.constant.int 0 + %cpu_6 = torch.constant.device "cpu" + %none_7 = torch.constant.none + %465 = torch.aten.scalar_tensor %int0_4, %int5, %int0_5, %cpu_6, %none_7 : !torch.int, !torch.int, !torch.int, !torch.Device, !torch.none -> !torch.vtensor<[],f16> + %466 = torch.aten.where.self %460, %464, %465 : !torch.vtensor<[4,?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[],f16> -> !torch.vtensor<[4,?],f16> + torch.bind_symbolic_shape %466, [%453], affine_map<()[s0] -> (4, s0 * 32)> : !torch.vtensor<[4,?],f16> + %int5_8 = torch.constant.int 5 + %467 = torch.prims.convert_element_type %466, %int5_8 : !torch.vtensor<[4,?],f16>, !torch.int -> !torch.vtensor<[4,?],f16> + torch.bind_symbolic_shape %467, [%453], affine_map<()[s0] -> (4, s0 * 32)> : !torch.vtensor<[4,?],f16> + %int1_9 = torch.constant.int 1 + %468 = torch.aten.unsqueeze %467, %int1_9 : !torch.vtensor<[4,?],f16>, !torch.int -> !torch.vtensor<[4,1,?],f16> + torch.bind_symbolic_shape %468, [%453], affine_map<()[s0] -> (4, 1, s0 * 32)> : !torch.vtensor<[4,1,?],f16> %int1_10 = torch.constant.int 1 - %367 = torch.aten.unsqueeze %366, %int1_10 : !torch.vtensor<[4,?],f16>, !torch.int -> !torch.vtensor<[4,1,?],f16> - torch.bind_symbolic_shape %367, [%356], affine_map<()[s0] -> (4, 1, s0 * 32)> : !torch.vtensor<[4,1,?],f16> - %int1_11 = torch.constant.int 1 - %368 = torch.aten.unsqueeze %367, %int1_11 : !torch.vtensor<[4,1,?],f16>, !torch.int -> !torch.vtensor<[4,1,1,?],f16> - torch.bind_symbolic_shape %368, [%356], affine_map<()[s0] -> (4, 1, 1, s0 * 32)> : !torch.vtensor<[4,1,1,?],f16> + %469 = torch.aten.unsqueeze %468, %int1_10 : !torch.vtensor<[4,1,?],f16>, !torch.int -> !torch.vtensor<[4,1,1,?],f16> + torch.bind_symbolic_shape %469, [%453], affine_map<()[s0] -> (4, 1, 1, s0 * 32)> : !torch.vtensor<[4,1,1,?],f16> + %int5_11 = torch.constant.int 5 + %470 = torch.prims.convert_element_type %469, %int5_11 : !torch.vtensor<[4,1,1,?],f16>, !torch.int -> !torch.vtensor<[4,1,1,?],f16> + torch.bind_symbolic_shape %470, [%453], affine_map<()[s0] -> (4, 1, 1, s0 * 32)> : !torch.vtensor<[4,1,1,?],f16> %int0_12 = torch.constant.int 0 %int1_13 = torch.constant.int 1 %none_14 = torch.constant.none %none_15 = torch.constant.none %cpu_16 = torch.constant.device "cpu" %false_17 = torch.constant.bool false - %369 = torch.aten.arange.start %int0_12, %int1_13, %none_14, %none_15, %cpu_16, %false_17 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1],si64> + %471 = torch.aten.arange.start %int0_12, %int1_13, %none_14, %none_15, %cpu_16, %false_17 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, 
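The mask hunk above replaces the old f32 path (scalar_tensor 0 and -inf, where, then a cast to f16) with a direct f16 select against the 0xFC00 (-inf) literal, so no f32 intermediate is materialized. A minimal PyTorch sketch of what the new path computes; the function and argument names are ours, not from the IR:

import torch

def prefill_attention_mask(seq_lens: torch.Tensor, pages_per_seq: int) -> torch.Tensor:
    # One page covers 32 tokens (the `* 32` block stride in the IR), so the
    # mask spans every position the allocated pages can hold.
    positions = torch.arange(pages_per_seq * 32)                # %457/%458
    padding = positions >= seq_lens.unsqueeze(-1)               # %459/%460 -> [4, n] bool
    neg_inf = torch.tensor(float("-inf"), dtype=torch.float16)  # the 0xFC00 literal
    zero = torch.tensor(0.0, dtype=torch.float16)
    mask = torch.where(padding, neg_inf, zero)                  # %466, f16 throughout
    return mask.unsqueeze(1).unsqueeze(1)                       # %468/%469 -> [4,1,1,n]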
!torch.bool -> !torch.vtensor<[1],si64> %int0_18 = torch.constant.int 0 - %370 = torch.aten.unsqueeze %369, %int0_18 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> + %472 = torch.aten.unsqueeze %471, %int0_18 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> %int1_19 = torch.constant.int 1 - %371 = torch.aten.unsqueeze %arg2, %int1_19 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %473 = torch.aten.unsqueeze %arg2, %int1_19 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> %int1_20 = torch.constant.int 1 - %372 = torch.aten.add.Tensor %370, %371, %int1_20 : !torch.vtensor<[1,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %474 = torch.aten.add.Tensor %472, %473, %int1_20 : !torch.vtensor<[1,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> %int131072 = torch.constant.int 131072 %none_21 = torch.constant.none %none_22 = torch.constant.none %cpu_23 = torch.constant.device "cpu" %false_24 = torch.constant.bool false - %373 = torch.aten.arange %int131072, %none_21, %none_22, %cpu_23, %false_24 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> + %475 = torch.aten.arange %int131072, %none_21, %none_22, %cpu_23, %false_24 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[131072],si64> %int0_25 = torch.constant.int 0 %int128 = torch.constant.int 128 - %none_26 = torch.constant.none - %none_27 = torch.constant.none - %cpu_28 = torch.constant.device "cpu" - %false_29 = torch.constant.bool false - %374 = torch.aten.arange.start %int0_25, %int128, %none_26, %none_27, %cpu_28, %false_29 : !torch.int, !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[128],si64> %int2 = torch.constant.int 2 - %375 = torch.aten.floor_divide.Scalar %374, %int2 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],si64> - %int6_30 = torch.constant.int 6 - %376 = torch.prims.convert_element_type %375, %int6_30 : !torch.vtensor<[128],si64>, !torch.int -> !torch.vtensor<[128],f32> - %int128_31 = torch.constant.int 128 - %377 = torch.aten.div.Scalar %376, %int128_31 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> - %float2.000000e00 = torch.constant.float 2.000000e+00 - %378 = torch.aten.mul.Scalar %377, %float2.000000e00 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32> + %int4 = torch.constant.int 4 + %none_26 = torch.constant.none + %cpu_27 = torch.constant.device "cpu" + %false_28 = torch.constant.bool false + %476 = torch.aten.arange.start_step %int0_25, %int128, %int2, %int4, %none_26, %cpu_27, %false_28 : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[64],si64> + %int6 = torch.constant.int 6 + %477 = torch.prims.convert_element_type %476, %int6 : !torch.vtensor<[64],si64>, !torch.int -> !torch.vtensor<[64],f32> + %int128_29 = torch.constant.int 128 + %478 = torch.aten.div.Scalar %477, %int128_29 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> %float5.000000e05 = torch.constant.float 5.000000e+05 - %379 = torch.aten.pow.Scalar %float5.000000e05, %378 : !torch.float, !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> - %380 = torch.aten.reciprocal %379 : !torch.vtensor<[128],f32> -> !torch.vtensor<[128],f32> + %479 = torch.aten.pow.Scalar %float5.000000e05, %478 : !torch.float, !torch.vtensor<[64],f32> -> 
!torch.vtensor<[64],f32>
+ %480 = torch.aten.reciprocal %479 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
  %float1.000000e00 = torch.constant.float 1.000000e+00
- %381 = torch.aten.mul.Scalar %380, %float1.000000e00 : !torch.vtensor<[128],f32>, !torch.float -> !torch.vtensor<[128],f32>
+ %481 = torch.aten.mul.Scalar %480, %float1.000000e00 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %482 = torch.aten.reciprocal %481 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %float6.283190e00 = torch.constant.float 6.2831853071795862
+ %483 = torch.aten.mul.Scalar %482, %float6.283190e00 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],f32>
+ %float8.192000e03 = torch.constant.float 8.192000e+03
+ %484 = torch.aten.gt.Scalar %483, %float8.192000e03 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1>
+ %int8 = torch.constant.int 8
+ %485 = torch.aten.div.Scalar %481, %int8 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %486 = torch.aten.where.self %484, %485, %481 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %487 = torch.aten.reciprocal %483 : !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8192 = torch.constant.int 8192
+ %488 = torch.aten.mul.Scalar %487, %int8192 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %int1_30 = torch.constant.int 1
+ %int1_31 = torch.constant.int 1
+ %489 = torch.aten.sub.Scalar %488, %int1_30, %int1_31 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %int3 = torch.constant.int 3
+ %490 = torch.aten.div.Scalar %489, %int3 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
  %int1_32 = torch.constant.int 1
- %382 = torch.aten.unsqueeze %373, %int1_32 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64>
- %int0_33 = torch.constant.int 0
- %383 = torch.aten.unsqueeze %381, %int0_33 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32>
- %384 = torch.aten.mul.Tensor %382, %383 : !torch.vtensor<[131072,1],si64>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32>
- %int4 = torch.constant.int 4
- %385 = torch.prim.ListConstruct %int4 : (!torch.int) -> !torch.list
- %386 = torch.aten.view %372, %385 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4],si64>
- %387 = torch.prim.ListConstruct %386 : (!torch.vtensor<[4],si64>) -> !torch.list>
- %388 = torch.aten.index.Tensor %384, %387 : !torch.vtensor<[131072,128],f32>, !torch.list> -> !torch.vtensor<[4,128],f32>
- %int1_34 = torch.constant.int 1
- %389 = torch.aten.unsqueeze %388, %int1_34 : !torch.vtensor<[4,128],f32>, !torch.int -> !torch.vtensor<[4,1,128],f32>
- %int-1_35 = torch.constant.int -1
- %false_36 = torch.constant.bool false
- %false_37 = torch.constant.bool false
- %390 = torch.aten.embedding %0, %arg0, %int-1_35, %false_36, %false_37 : !torch.vtensor<[128256,4096],f16>, !torch.vtensor<[4,1],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,1,4096],f16>
+ %int1_33 = torch.constant.int 1
+ %491 = torch.aten.rsub.Scalar %490, %int1_32, %int1_33 : !torch.vtensor<[64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64],f32>
+ %492 = torch.aten.mul.Tensor %491, %486 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32>
+ %int8_34 = torch.constant.int 8
+ %493 = torch.aten.div.Scalar %492, %int8_34 : !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32>
+ %494 =
torch.aten.mul.Tensor %490, %486 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %int1_35 = torch.constant.int 1 + %495 = torch.aten.add.Tensor %493, %494, %int1_35 : !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.int -> !torch.vtensor<[64],f32> + %float2.048000e03 = torch.constant.float 2.048000e+03 + %496 = torch.aten.lt.Scalar %483, %float2.048000e03 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %497 = torch.aten.bitwise_not %496 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %float8.192000e03_36 = torch.constant.float 8.192000e+03 + %498 = torch.aten.gt.Scalar %483, %float8.192000e03_36 : !torch.vtensor<[64],f32>, !torch.float -> !torch.vtensor<[64],i1> + %499 = torch.aten.bitwise_not %498 : !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %500 = torch.aten.mul.Tensor %497, %499 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],i1> -> !torch.vtensor<[64],i1> + %501 = torch.aten.where.self %500, %495, %486 : !torch.vtensor<[64],i1>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[64],f32> + %502 = torch.prim.ListConstruct %501, %501 : (!torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.list + %int-1_37 = torch.constant.int -1 + %503 = torch.aten.cat %502, %int-1_37 : !torch.list, !torch.int -> !torch.vtensor<[128],f32> %int6_38 = torch.constant.int 6 - %391 = torch.prims.convert_element_type %390, %int6_38 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_39 = torch.constant.int 2 - %392 = torch.aten.pow.Tensor_Scalar %391, %int2_39 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_40 = torch.constant.int -1 - %393 = torch.prim.ListConstruct %int-1_40 : (!torch.int) -> !torch.list - %true = torch.constant.bool true - %none_41 = torch.constant.none - %394 = torch.aten.mean.dim %392, %393, %true, %none_41 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06 = torch.constant.float 9.9999997473787516E-6 - %int1_42 = torch.constant.int 1 - %395 = torch.aten.add.Scalar %394, %float9.999990e-06, %int1_42 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %396 = torch.aten.rsqrt %395 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %397 = torch.aten.mul.Tensor %391, %396 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %504 = torch.prims.convert_element_type %503, %int6_38 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[128],f32> + %int1_39 = torch.constant.int 1 + %505 = torch.aten.unsqueeze %475, %int1_39 : !torch.vtensor<[131072],si64>, !torch.int -> !torch.vtensor<[131072,1],si64> + %int6_40 = torch.constant.int 6 + %506 = torch.prims.convert_element_type %505, %int6_40 : !torch.vtensor<[131072,1],si64>, !torch.int -> !torch.vtensor<[131072,1],f32> + %int0_41 = torch.constant.int 0 + %507 = torch.aten.unsqueeze %504, %int0_41 : !torch.vtensor<[128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %int6_42 = torch.constant.int 6 + %508 = torch.prims.convert_element_type %507, %int6_42 : !torch.vtensor<[1,128],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + %509 = torch.aten.mul.Tensor %506, %508 : !torch.vtensor<[131072,1],f32>, !torch.vtensor<[1,128],f32> -> !torch.vtensor<[131072,128],f32> + %510 = torch.aten.cos %509 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> %int5_43 = torch.constant.int 5 - 
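The frequency computation above no longer builds all 128 entries and halves them with floor_divide; it generates the 64 even exponents directly (arange(0, 128, 2) / 128) and then applies Llama 3.1-style long-context RoPE scaling: wavelengths above 8192 are slowed by a factor of 8, wavelengths below 2048 are left alone, and the band in between is blended smoothly. A PyTorch sketch; the constant names are our own, the literal values are the ones in the hunk:

import math
import torch

ROPE_BASE = 500_000.0   # %float5.000000e05
HEAD_DIM = 128
OLD_CONTEXT = 8192      # %int8192
SCALE_FACTOR = 8        # %int8
LOW_FREQ_FACTOR = 1.0   # the `- 1` in %489
HIGH_FREQ_FACTOR = 4.0  # 8192 / 4 gives the 2048 threshold

def llama3_scaled_inv_freq() -> torch.Tensor:
    exponents = torch.arange(0, HEAD_DIM, 2, dtype=torch.float32) / HEAD_DIM  # %476-%478
    inv_freq = 1.0 / (ROPE_BASE ** exponents)                                 # %479/%480
    wavelen = 2 * math.pi / inv_freq                                          # %482/%483
    low_freq_wavelen = OLD_CONTEXT / LOW_FREQ_FACTOR    # 8192
    high_freq_wavelen = OLD_CONTEXT / HIGH_FREQ_FACTOR  # 2048
    # Long wavelengths are slowed by the scale factor (%484-%486).
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / SCALE_FACTOR, inv_freq)
    # In the 2048..8192 band, blend smoothly between scaled and unscaled (%487-%501).
    smooth = (OLD_CONTEXT / wavelen - LOW_FREQ_FACTOR) / (HIGH_FREQ_FACTOR - LOW_FREQ_FACTOR)
    blended = (1 - smooth) * inv_freq / SCALE_FACTOR + smooth * inv_freq
    mid = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(mid, blended, scaled)

The 64 scaled frequencies are concatenated with themselves into 128 entries (%502/%503) before the cos/sin tables are built, which matches the half-rotation layout used when the rotation is applied later.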
%398 = torch.prims.convert_element_type %397, %int5_43 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %399 = torch.aten.mul.Tensor %1, %398 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %511 = torch.prims.convert_element_type %510, %int5_43 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %512 = torch.aten.sin %509 : !torch.vtensor<[131072,128],f32> -> !torch.vtensor<[131072,128],f32> %int5_44 = torch.constant.int 5 - %400 = torch.prims.convert_element_type %399, %int5_44 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %513 = torch.prims.convert_element_type %512, %int5_44 : !torch.vtensor<[131072,128],f32>, !torch.int -> !torch.vtensor<[131072,128],f16> + %int4_45 = torch.constant.int 4 + %514 = torch.prim.ListConstruct %int4_45 : (!torch.int) -> !torch.list + %515 = torch.aten.view %474, %514 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4],si64> + %int1_46 = torch.constant.int 1 + %int0_47 = torch.constant.int 0 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %int1_48 = torch.constant.int 1 + %516 = torch.aten.slice.Tensor %511, %int1_46, %int0_47, %int9223372036854775807, %int1_48 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[131072,128],f16> + %517 = torch.prim.ListConstruct %515 : (!torch.vtensor<[4],si64>) -> !torch.list> + %518 = torch.aten.index.Tensor %516, %517 : !torch.vtensor<[131072,128],f16>, !torch.list> -> !torch.vtensor<[4,128],f16> + %int4_49 = torch.constant.int 4 + %519 = torch.prim.ListConstruct %int4_49 : (!torch.int) -> !torch.list + %520 = torch.aten.view %474, %519 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4],si64> + %int1_50 = torch.constant.int 1 + %int0_51 = torch.constant.int 0 + %int9223372036854775807_52 = torch.constant.int 9223372036854775807 + %int1_53 = torch.constant.int 1 + %521 = torch.aten.slice.Tensor %513, %int1_50, %int0_51, %int9223372036854775807_52, %int1_53 : !torch.vtensor<[131072,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[131072,128],f16> + %522 = torch.prim.ListConstruct %520 : (!torch.vtensor<[4],si64>) -> !torch.list> + %523 = torch.aten.index.Tensor %521, %522 : !torch.vtensor<[131072,128],f16>, !torch.list> -> !torch.vtensor<[4,128],f16> + %int0_54 = torch.constant.int 0 + %int0_55 = torch.constant.int 0 + %int9223372036854775807_56 = torch.constant.int 9223372036854775807 + %int1_57 = torch.constant.int 1 + %524 = torch.aten.slice.Tensor %518, %int0_54, %int0_55, %int9223372036854775807_56, %int1_57 : !torch.vtensor<[4,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,128],f16> + %int1_58 = torch.constant.int 1 + %525 = torch.aten.unsqueeze %524, %int1_58 : !torch.vtensor<[4,128],f16>, !torch.int -> !torch.vtensor<[4,1,128],f16> + %int2_59 = torch.constant.int 2 + %526 = torch.aten.unsqueeze %525, %int2_59 : !torch.vtensor<[4,1,128],f16>, !torch.int -> !torch.vtensor<[4,1,1,128],f16> + %int3_60 = torch.constant.int 3 + %int0_61 = torch.constant.int 0 + %int9223372036854775807_62 = torch.constant.int 9223372036854775807 + %int1_63 = torch.constant.int 1 + %527 = torch.aten.slice.Tensor %526, %int3_60, %int0_61, %int9223372036854775807_62, %int1_63 : !torch.vtensor<[4,1,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1,1,128],f16> + %int0_64 = torch.constant.int 0 + %int0_65 = 
torch.constant.int 0 + %int9223372036854775807_66 = torch.constant.int 9223372036854775807 + %int1_67 = torch.constant.int 1 + %528 = torch.aten.slice.Tensor %523, %int0_64, %int0_65, %int9223372036854775807_66, %int1_67 : !torch.vtensor<[4,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,128],f16> + %int1_68 = torch.constant.int 1 + %529 = torch.aten.unsqueeze %528, %int1_68 : !torch.vtensor<[4,128],f16>, !torch.int -> !torch.vtensor<[4,1,128],f16> + %int2_69 = torch.constant.int 2 + %530 = torch.aten.unsqueeze %529, %int2_69 : !torch.vtensor<[4,1,128],f16>, !torch.int -> !torch.vtensor<[4,1,1,128],f16> + %int3_70 = torch.constant.int 3 + %int0_71 = torch.constant.int 0 + %int9223372036854775807_72 = torch.constant.int 9223372036854775807 + %int1_73 = torch.constant.int 1 + %531 = torch.aten.slice.Tensor %530, %int3_70, %int0_71, %int9223372036854775807_72, %int1_73 : !torch.vtensor<[4,1,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1,1,128],f16> + %int5_74 = torch.constant.int 5 + %532 = torch.prims.convert_element_type %1, %int5_74 : !torch.vtensor<[128256,4096],f16>, !torch.int -> !torch.vtensor<[128256,4096],f16> + %int-1_75 = torch.constant.int -1 + %false_76 = torch.constant.bool false + %false_77 = torch.constant.bool false + %533 = torch.aten.embedding %532, %arg0, %int-1_75, %false_76, %false_77 : !torch.vtensor<[128256,4096],f16>, !torch.vtensor<[4,1],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,1,4096],f16> + %int6_78 = torch.constant.int 6 + %534 = torch.prims.convert_element_type %533, %int6_78 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_79 = torch.constant.int 2 + %535 = torch.aten.pow.Tensor_Scalar %534, %int2_79 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_80 = torch.constant.int -1 + %536 = torch.prim.ListConstruct %int-1_80 : (!torch.int) -> !torch.list + %true = torch.constant.bool true + %none_81 = torch.constant.none + %537 = torch.aten.mean.dim %535, %536, %true, %none_81 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06 = torch.constant.float 9.9999997473787516E-6 + %int1_82 = torch.constant.int 1 + %538 = torch.aten.add.Scalar %537, %float9.999990e-06, %int1_82 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %539 = torch.aten.rsqrt %538 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %540 = torch.aten.mul.Tensor %534, %539 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_83 = torch.constant.int 5 + %541 = torch.prims.convert_element_type %540, %int5_83 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %542 = torch.aten.mul.Tensor %2, %541 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_84 = torch.constant.int 5 + %543 = torch.prims.convert_element_type %542, %int5_84 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> %int-2 = torch.constant.int -2 - %int-1_45 = torch.constant.int -1 - %401 = torch.aten.transpose.int %2, %int-2, %int-1_45 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_46 = torch.constant.int 4 + %int-1_85 = torch.constant.int -1 + %544 = torch.aten.transpose.int %3, %int-2, %int-1_85 : !torch.vtensor<[4096,4096],f16>, !torch.int, 
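After the embedding lookup the hunk performs a standard RMSNorm with eps ~ 1e-5 (the literal 9.9999997473787516E-6): statistics in f32, normalize, scale by the f32 norm weight, return to f16. A compact sketch, names ours:

import torch

def rms_norm(x_f16: torch.Tensor, weight_f32: torch.Tensor,
             eps: float = 9.9999997473787516e-06) -> torch.Tensor:
    # The IR upcasts to f32 for the statistics (%534), normalizes,
    # then multiplies by the f32 weight and casts back to f16 (%541-%543).
    x = x_f16.to(torch.float32)
    var = x.pow(2).mean(dim=-1, keepdim=True)       # %535/%537
    x = x * torch.rsqrt(var + eps)                  # %538-%540
    return (weight_f32 * x.to(torch.float16)).to(torch.float16)  # %542/%543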
!torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_86 = torch.constant.int 5 + %545 = torch.prims.convert_element_type %544, %int5_86 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_87 = torch.constant.int 4 %int4096 = torch.constant.int 4096 - %402 = torch.prim.ListConstruct %int4_46, %int4096 : (!torch.int, !torch.int) -> !torch.list - %403 = torch.aten.view %400, %402 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %404 = torch.aten.mm %403, %401 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_47 = torch.constant.int 4 - %int1_48 = torch.constant.int 1 - %int4096_49 = torch.constant.int 4096 - %405 = torch.prim.ListConstruct %int4_47, %int1_48, %int4096_49 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %406 = torch.aten.view %404, %405 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_50 = torch.constant.int -2 - %int-1_51 = torch.constant.int -1 - %407 = torch.aten.transpose.int %3, %int-2_50, %int-1_51 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_52 = torch.constant.int 4 - %int4096_53 = torch.constant.int 4096 - %408 = torch.prim.ListConstruct %int4_52, %int4096_53 : (!torch.int, !torch.int) -> !torch.list - %409 = torch.aten.view %400, %408 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %410 = torch.aten.mm %409, %407 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_54 = torch.constant.int 4 - %int1_55 = torch.constant.int 1 + %546 = torch.prim.ListConstruct %int4_87, %int4096 : (!torch.int, !torch.int) -> !torch.list + %547 = torch.aten.view %543, %546 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %548 = torch.aten.mm %547, %545 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_88 = torch.constant.int 4 + %int1_89 = torch.constant.int 1 + %int4096_90 = torch.constant.int 4096 + %549 = torch.prim.ListConstruct %int4_88, %int1_89, %int4096_90 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %550 = torch.aten.view %548, %549 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_91 = torch.constant.int -2 + %int-1_92 = torch.constant.int -1 + %551 = torch.aten.transpose.int %4, %int-2_91, %int-1_92 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_93 = torch.constant.int 5 + %552 = torch.prims.convert_element_type %551, %int5_93 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_94 = torch.constant.int 4 + %int4096_95 = torch.constant.int 4096 + %553 = torch.prim.ListConstruct %int4_94, %int4096_95 : (!torch.int, !torch.int) -> !torch.list + %554 = torch.aten.view %543, %553 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %555 = torch.aten.mm %554, %552 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_96 = torch.constant.int 4 + %int1_97 = torch.constant.int 1 %int1024 = torch.constant.int 1024 - %411 = torch.prim.ListConstruct %int4_54, %int1_55, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %412 = torch.aten.view %410, %411 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_56 = torch.constant.int -2 - %int-1_57 = 
torch.constant.int -1 - %413 = torch.aten.transpose.int %4, %int-2_56, %int-1_57 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_58 = torch.constant.int 4 - %int4096_59 = torch.constant.int 4096 - %414 = torch.prim.ListConstruct %int4_58, %int4096_59 : (!torch.int, !torch.int) -> !torch.list - %415 = torch.aten.view %400, %414 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %416 = torch.aten.mm %415, %413 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_60 = torch.constant.int 4 - %int1_61 = torch.constant.int 1 - %int1024_62 = torch.constant.int 1024 - %417 = torch.prim.ListConstruct %int4_60, %int1_61, %int1024_62 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %418 = torch.aten.view %416, %417 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_63 = torch.constant.int 4 - %int1_64 = torch.constant.int 1 - %int32_65 = torch.constant.int 32 - %int128_66 = torch.constant.int 128 - %419 = torch.prim.ListConstruct %int4_63, %int1_64, %int32_65, %int128_66 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %420 = torch.aten.view %406, %419 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_67 = torch.constant.int 4 - %int1_68 = torch.constant.int 1 - %int8 = torch.constant.int 8 - %int128_69 = torch.constant.int 128 - %421 = torch.prim.ListConstruct %int4_67, %int1_68, %int8, %int128_69 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %422 = torch.aten.view %412, %421 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_70 = torch.constant.int 4 - %int1_71 = torch.constant.int 1 - %int8_72 = torch.constant.int 8 - %int128_73 = torch.constant.int 128 - %423 = torch.prim.ListConstruct %int4_70, %int1_71, %int8_72, %int128_73 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %424 = torch.aten.view %418, %423 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_74 = torch.constant.int 6 - %425 = torch.prims.convert_element_type %420, %int6_74 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %426 = torch_c.to_builtin_tensor %425 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %427 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %428 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%426, %427) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %429 = torch_c.from_builtin_tensor %428 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_75 = torch.constant.int 5 - %430 = torch.prims.convert_element_type %429, %int5_75 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_76 = torch.constant.int 6 - %431 = torch.prims.convert_element_type %422, %int6_76 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %432 = torch_c.to_builtin_tensor %431 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %433 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %434 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%432, %433) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %435 = torch_c.from_builtin_tensor %434 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_77 = 
torch.constant.int 5 - %436 = torch.prims.convert_element_type %435, %int5_77 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int0_78 = torch.constant.int 0 - %437 = torch.aten.size.int %355, %int0_78 : !torch.vtensor<[?,2097152],f16>, !torch.int -> !torch.int - %int32_79 = torch.constant.int 32 - %int2_80 = torch.constant.int 2 - %int32_81 = torch.constant.int 32 - %int8_82 = torch.constant.int 8 - %int128_83 = torch.constant.int 128 - %438 = torch.prim.ListConstruct %437, %int32_79, %int2_80, %int32_81, %int8_82, %int128_83 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %439 = torch.aten.view %355, %438 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %439, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_84 = torch.constant.int 32 - %440 = torch.aten.mul.int %437, %int32_84 : !torch.int, !torch.int -> !torch.int - %int2_85 = torch.constant.int 2 - %441 = torch.aten.mul.int %440, %int2_85 : !torch.int, !torch.int -> !torch.int - %int32_86 = torch.constant.int 32 - %442 = torch.aten.mul.int %441, %int32_86 : !torch.int, !torch.int -> !torch.int - %int8_87 = torch.constant.int 8 - %int128_88 = torch.constant.int 128 - %443 = torch.prim.ListConstruct %442, %int8_87, %int128_88 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %444 = torch.aten.view %439, %443 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %444, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_89 = torch.constant.int 32 - %445 = torch.aten.floor_divide.Scalar %arg2, %int32_89 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_90 = torch.constant.int 1 - %446 = torch.aten.unsqueeze %445, %int1_90 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_91 = torch.constant.int 1 - %false_92 = torch.constant.bool false - %447 = torch.aten.gather %arg3, %int1_91, %446, %false_92 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_93 = torch.constant.int 32 - %448 = torch.aten.remainder.Scalar %arg2, %int32_93 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_94 = torch.constant.int 1 - %449 = torch.aten.unsqueeze %448, %int1_94 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_95 = torch.constant.none - %450 = torch.aten.clone %5, %none_95 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_96 = torch.constant.int 0 - %451 = torch.aten.unsqueeze %450, %int0_96 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_97 = torch.constant.int 4 - %int1_98 = torch.constant.int 1 - %452 = torch.prim.ListConstruct %int4_97, %int1_98 : (!torch.int, !torch.int) -> !torch.list - %int1_99 = torch.constant.int 1 - %int1_100 = torch.constant.int 1 - %453 = torch.prim.ListConstruct %int1_99, %int1_100 : (!torch.int, !torch.int) -> !torch.list + %556 = torch.prim.ListConstruct %int4_96, %int1_97, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %557 = torch.aten.view %555, %556 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_98 = torch.constant.int -2 + %int-1_99 = torch.constant.int -1 + %558 = torch.aten.transpose.int %5, %int-2_98, %int-1_99 : 
!torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_100 = torch.constant.int 5 + %559 = torch.prims.convert_element_type %558, %int5_100 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> %int4_101 = torch.constant.int 4 - %int0_102 = torch.constant.int 0 - %cpu_103 = torch.constant.device "cpu" - %false_104 = torch.constant.bool false - %454 = torch.aten.empty_strided %452, %453, %int4_101, %int0_102, %cpu_103, %false_104 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int0_105 = torch.constant.int 0 - %455 = torch.aten.fill.Scalar %454, %int0_105 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int4096_102 = torch.constant.int 4096 + %560 = torch.prim.ListConstruct %int4_101, %int4096_102 : (!torch.int, !torch.int) -> !torch.list + %561 = torch.aten.view %543, %560 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %562 = torch.aten.mm %561, %559 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_103 = torch.constant.int 4 + %int1_104 = torch.constant.int 1 + %int1024_105 = torch.constant.int 1024 + %563 = torch.prim.ListConstruct %int4_103, %int1_104, %int1024_105 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %564 = torch.aten.view %562, %563 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> %int4_106 = torch.constant.int 4 %int1_107 = torch.constant.int 1 - %456 = torch.prim.ListConstruct %int4_106, %int1_107 : (!torch.int, !torch.int) -> !torch.list - %457 = torch.aten.repeat %451, %456 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> %int32_108 = torch.constant.int 32 - %458 = torch.aten.mul.Scalar %447, %int32_108 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_109 = torch.constant.int 1 - %459 = torch.aten.add.Tensor %458, %455, %int1_109 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_110 = torch.constant.int 2 - %460 = torch.aten.mul.Scalar %459, %int2_110 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int128_109 = torch.constant.int 128 + %565 = torch.prim.ListConstruct %int4_106, %int1_107, %int32_108, %int128_109 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %566 = torch.aten.view %550, %565 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_110 = torch.constant.int 4 %int1_111 = torch.constant.int 1 - %461 = torch.aten.add.Tensor %460, %457, %int1_111 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_112 = torch.constant.int 32 - %462 = torch.aten.mul.Scalar %461, %int32_112 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_113 = torch.constant.int 1 - %463 = torch.aten.add.Tensor %462, %449, %int1_113 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %464 = torch.prim.ListConstruct %463 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_114 = torch.constant.bool false - %465 = torch.aten.index_put %444, %464, %436, %false_114 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %465, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : 
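The Q/K/V projections keep the [out_features, in_features] weight layout of the globals: each is a transpose, a flat [4, 4096] matmul, and a view back to [4, 1, out]. Q produces 4096 = 32 x 128 while K/V produce 1024 = 8 x 128, i.e. grouped-query attention with 8 KV heads. A sketch of the equivalent computation, names ours:

import torch

def project_qkv(h: torch.Tensor, wq: torch.Tensor, wk: torch.Tensor, wv: torch.Tensor):
    # h: [4, 1, 4096]; weights stored [out_features, in_features] as in the
    # globals. The IR flattens to 2-D, does a plain mm against the transposed
    # weight (%544-%564), then views the result into heads (%565-%570).
    b, t, d = h.shape
    flat = h.reshape(b * t, d)
    q = (flat @ wq.t()).reshape(b, t, 32, 128)   # [4,1,4096] -> 32 query heads
    k = (flat @ wk.t()).reshape(b, t, 8, 128)    # [4,1,1024] ->  8 KV heads
    v = (flat @ wv.t()).reshape(b, t, 8, 128)
    return q, k, v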
!torch.vtensor<[?,8,128],f16> - %int32_115 = torch.constant.int 32 - %int2_116 = torch.constant.int 2 - %int32_117 = torch.constant.int 32 - %int8_118 = torch.constant.int 8 - %int128_119 = torch.constant.int 128 - %466 = torch.prim.ListConstruct %437, %int32_115, %int2_116, %int32_117, %int8_118, %int128_119 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %467 = torch.aten.view %465, %466 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %467, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152 = torch.constant.int 2097152 - %468 = torch.prim.ListConstruct %437, %int2097152 : (!torch.int, !torch.int) -> !torch.list - %469 = torch.aten.view %467, %468 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %469, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_120 = torch.constant.int 32 - %int2_121 = torch.constant.int 2 - %int32_122 = torch.constant.int 32 - %int8_123 = torch.constant.int 8 - %int128_124 = torch.constant.int 128 - %470 = torch.prim.ListConstruct %437, %int32_120, %int2_121, %int32_122, %int8_123, %int128_124 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %471 = torch.aten.view %469, %470 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %471, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_125 = torch.constant.int 8 - %int128_126 = torch.constant.int 128 - %472 = torch.prim.ListConstruct %442, %int8_125, %int128_126 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %473 = torch.aten.view %471, %472 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %473, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_127 = torch.constant.int 32 - %474 = torch.aten.floor_divide.Scalar %arg2, %int32_127 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int8_112 = torch.constant.int 8 + %int128_113 = torch.constant.int 128 + %567 = torch.prim.ListConstruct %int4_110, %int1_111, %int8_112, %int128_113 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %568 = torch.aten.view %557, %567 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_114 = torch.constant.int 4 + %int1_115 = torch.constant.int 1 + %int8_116 = torch.constant.int 8 + %int128_117 = torch.constant.int 128 + %569 = torch.prim.ListConstruct %int4_114, %int1_115, %int8_116, %int128_117 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %570 = torch.aten.view %564, %569 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_118 = torch.constant.int 1 + %int2_119 = torch.constant.int 2 + %571 = torch.aten.transpose.int %566, %int1_118, %int2_119 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %572 = torch.aten.mul.Tensor %571, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_120 = torch.constant.int 3 + %int0_121 = torch.constant.int 0 + %int64 = torch.constant.int 64 + %int1_122 = torch.constant.int 1 + %573 = torch.aten.slice.Tensor %571, %int3_120, %int0_121, 
%int64, %int1_122 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_123 = torch.constant.int 3 + %int64_124 = torch.constant.int 64 + %int9223372036854775807_125 = torch.constant.int 9223372036854775807 + %int1_126 = torch.constant.int 1 + %574 = torch.aten.slice.Tensor %571, %int3_123, %int64_124, %int9223372036854775807_125, %int1_126 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %575 = torch.aten.neg %574 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %576 = torch.prim.ListConstruct %575, %573 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_127 = torch.constant.int -1 + %577 = torch.aten.cat %576, %int-1_127 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %578 = torch.aten.mul.Tensor %577, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> %int1_128 = torch.constant.int 1 - %475 = torch.aten.unsqueeze %474, %int1_128 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %579 = torch.aten.add.Tensor %572, %578, %int1_128 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_129 = torch.constant.int 1 - %false_130 = torch.constant.bool false - %476 = torch.aten.gather %arg3, %int1_129, %475, %false_130 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_131 = torch.constant.int 32 - %477 = torch.aten.remainder.Scalar %arg2, %int32_131 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_132 = torch.constant.int 1 - %478 = torch.aten.unsqueeze %477, %int1_132 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_133 = torch.constant.none - %479 = torch.aten.clone %6, %none_133 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %int2_130 = torch.constant.int 2 + %580 = torch.aten.transpose.int %579, %int1_129, %int2_130 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_131 = torch.constant.int 1 + %int2_132 = torch.constant.int 2 + %581 = torch.aten.transpose.int %568, %int1_131, %int2_132 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %582 = torch.aten.mul.Tensor %581, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_133 = torch.constant.int 3 %int0_134 = torch.constant.int 0 - %480 = torch.aten.unsqueeze %479, %int0_134 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_135 = torch.constant.int 4 + %int64_135 = torch.constant.int 64 %int1_136 = torch.constant.int 1 - %481 = torch.prim.ListConstruct %int4_135, %int1_136 : (!torch.int, !torch.int) -> !torch.list - %int1_137 = torch.constant.int 1 - %int1_138 = torch.constant.int 1 - %482 = torch.prim.ListConstruct %int1_137, %int1_138 : (!torch.int, !torch.int) -> !torch.list - %int4_139 = torch.constant.int 4 - %int0_140 = torch.constant.int 0 - %cpu_141 = torch.constant.device "cpu" - %false_142 = torch.constant.bool false - %483 = torch.aten.empty_strided %481, %482, %int4_139, %int0_140, %cpu_141, %false_142 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int0_143 = torch.constant.int 0 
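The new hunk inlines the rotary application instead of calling the sharktank_rotary_embedding_* helpers on f32 operands: heads move to dim 1, then q*cos + rotate_half(q)*sin is computed entirely in f16 against the per-position cos/sin gathered earlier (%527/%531). Sketch, names ours:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # %573-%577: split the 128-dim head into two 64-dim halves and swap
    # them with a sign flip: (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., :64], x[..., 64:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [4, heads, 1, 128]; cos/sin: [4, 1, 1, 128].
    # Matches %572 + %578 -> %579 for Q (and %582/%588 -> %589 for K).
    return x * cos + rotate_half(x) * sin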
- %484 = torch.aten.fill.Scalar %483, %int0_143 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_144 = torch.constant.int 4 - %int1_145 = torch.constant.int 1 - %485 = torch.prim.ListConstruct %int4_144, %int1_145 : (!torch.int, !torch.int) -> !torch.list - %486 = torch.aten.repeat %480, %485 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_146 = torch.constant.int 32 - %487 = torch.aten.mul.Scalar %476, %int32_146 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_147 = torch.constant.int 1 - %488 = torch.aten.add.Tensor %487, %484, %int1_147 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_148 = torch.constant.int 2 - %489 = torch.aten.mul.Scalar %488, %int2_148 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_149 = torch.constant.int 1 - %490 = torch.aten.add.Tensor %489, %486, %int1_149 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %583 = torch.aten.slice.Tensor %581, %int3_133, %int0_134, %int64_135, %int1_136 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_137 = torch.constant.int 3 + %int64_138 = torch.constant.int 64 + %int9223372036854775807_139 = torch.constant.int 9223372036854775807 + %int1_140 = torch.constant.int 1 + %584 = torch.aten.slice.Tensor %581, %int3_137, %int64_138, %int9223372036854775807_139, %int1_140 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %585 = torch.aten.neg %584 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %586 = torch.prim.ListConstruct %585, %583 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_141 = torch.constant.int -1 + %587 = torch.aten.cat %586, %int-1_141 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %588 = torch.aten.mul.Tensor %587, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_142 = torch.constant.int 1 + %589 = torch.aten.add.Tensor %582, %588, %int1_142 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_143 = torch.constant.int 1 + %int2_144 = torch.constant.int 2 + %590 = torch.aten.transpose.int %589, %int1_143, %int2_144 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_145 = torch.constant.int 32 + %int2_146 = torch.constant.int 2 + %int8_147 = torch.constant.int 8 + %int32_148 = torch.constant.int 32 + %int128_149 = torch.constant.int 128 + %591 = torch.prim.ListConstruct %456, %int32_145, %int2_146, %int8_147, %int32_148, %int128_149 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %592 = torch.aten.view %452, %591 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %592, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> %int32_150 = torch.constant.int 32 - %491 = torch.aten.mul.Scalar %490, %int32_150 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_151 = torch.constant.int 1 - %492 = torch.aten.add.Tensor %491, %478, %int1_151 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> 
!torch.vtensor<[4,1],si64> - %493 = torch.prim.ListConstruct %492 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_152 = torch.constant.bool false - %494 = torch.aten.index_put %473, %493, %424, %false_152 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %494, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> + %593 = torch.aten.mul.int %456, %int32_150 : !torch.int, !torch.int -> !torch.int + %int2_151 = torch.constant.int 2 + %594 = torch.aten.mul.int %593, %int2_151 : !torch.int, !torch.int -> !torch.int + %int8_152 = torch.constant.int 8 + %595 = torch.aten.mul.int %594, %int8_152 : !torch.int, !torch.int -> !torch.int %int32_153 = torch.constant.int 32 - %int2_154 = torch.constant.int 2 + %596 = torch.aten.mul.int %595, %int32_153 : !torch.int, !torch.int -> !torch.int + %int128_154 = torch.constant.int 128 + %597 = torch.prim.ListConstruct %596, %int128_154 : (!torch.int, !torch.int) -> !torch.list + %598 = torch.aten.view %592, %597 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %598, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> %int32_155 = torch.constant.int 32 - %int8_156 = torch.constant.int 8 - %int128_157 = torch.constant.int 128 - %495 = torch.prim.ListConstruct %437, %int32_153, %int2_154, %int32_155, %int8_156, %int128_157 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %496 = torch.aten.view %494, %495 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %496, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_158 = torch.constant.int 2097152 - %497 = torch.prim.ListConstruct %437, %int2097152_158 : (!torch.int, !torch.int) -> !torch.list - %498 = torch.aten.view %496, %497 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %498, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %599 = torch.aten.floor_divide.Scalar %arg2, %int32_155 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_156 = torch.constant.int 1 + %600 = torch.aten.unsqueeze %599, %int1_156 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_157 = torch.constant.int 1 + %false_158 = torch.constant.bool false + %601 = torch.aten.gather %arg3, %int1_157, %600, %false_158 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> %int4_159 = torch.constant.int 4 - %499 = torch.prim.ListConstruct %int4_159, %358 : (!torch.int, !torch.int) -> !torch.list %int1_160 = torch.constant.int 1 - %500 = torch.prim.ListConstruct %358, %int1_160 : (!torch.int, !torch.int) -> !torch.list - %int4_161 = torch.constant.int 4 - %int0_162 = torch.constant.int 0 - %cpu_163 = torch.constant.device "cpu" - %false_164 = torch.constant.bool false - %501 = torch.aten.empty_strided %499, %500, %int4_161, %int0_162, %cpu_163, %false_164 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %501, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int0_165 = torch.constant.int 0 - %502 = torch.aten.fill.Scalar %501, %int0_165 : 
!torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %502, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_166 = torch.constant.int 32 - %503 = torch.aten.mul.Scalar %arg3, %int32_166 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %503, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_167 = torch.constant.int 1 - %504 = torch.aten.add.Tensor %503, %502, %int1_167 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %504, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_168 = torch.constant.int 4 - %505 = torch.aten.mul.int %int4_168, %358 : !torch.int, !torch.int -> !torch.int - %506 = torch.prim.ListConstruct %505 : (!torch.int) -> !torch.list - %507 = torch.aten.view %504, %506 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %507, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_169 = torch.constant.int 32 - %int2_170 = torch.constant.int 2 - %int32_171 = torch.constant.int 32 - %int8_172 = torch.constant.int 8 - %int128_173 = torch.constant.int 128 - %508 = torch.prim.ListConstruct %437, %int32_169, %int2_170, %int32_171, %int8_172, %int128_173 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %509 = torch.aten.view %498, %508 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %509, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_174 = torch.constant.int 32 - %510 = torch.aten.mul.int %437, %int32_174 : !torch.int, !torch.int -> !torch.int - %int2_175 = torch.constant.int 2 - %int32_176 = torch.constant.int 32 - %int8_177 = torch.constant.int 8 - %int128_178 = torch.constant.int 128 - %511 = torch.prim.ListConstruct %510, %int2_175, %int32_176, %int8_177, %int128_178 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %512 = torch.aten.view %509, %511 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %512, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> + %int1_161 = torch.constant.int 1 + %602 = torch.prim.ListConstruct %int4_159, %int1_160, %int1_161 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %603 = torch.aten.view %601, %602 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_162 = torch.constant.int 32 + %604 = torch.aten.remainder.Scalar %arg2, %int32_162 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_163 = torch.constant.int 4 + %int1_164 = torch.constant.int 1 + %int1_165 = torch.constant.int 1 + %605 = torch.prim.ListConstruct %int4_163, %int1_164, %int1_165 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %606 = torch.aten.view %604, %605 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_166 = torch.constant.int 8 + %none_167 = torch.constant.none + %none_168 = torch.constant.none + %cpu_169 = torch.constant.device "cpu" + %false_170 = torch.constant.bool false + %607 = torch.aten.arange %int8_166, %none_167, %none_168, %cpu_169, %false_170 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_171 = 
torch.constant.int 1 + %int1_172 = torch.constant.int 1 + %int8_173 = torch.constant.int 8 + %608 = torch.prim.ListConstruct %int1_171, %int1_172, %int8_173 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %609 = torch.aten.view %607, %608 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_174 = torch.constant.none + %610 = torch.aten.clone %6, %none_174 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %611 = torch.aten.detach %610 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %612 = torch.aten.detach %611 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %613 = torch.aten.detach %612 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_175 = torch.constant.int 1 + %int1_176 = torch.constant.int 1 + %int1_177 = torch.constant.int 1 + %614 = torch.prim.ListConstruct %int1_175, %int1_176, %int1_177 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %615 = torch.aten.view %613, %614 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_178 = torch.constant.int 32 + %616 = torch.aten.mul.Scalar %603, %int32_178 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int0_179 = torch.constant.int 0 - %513 = torch.aten.index_select %512, %int0_179, %507 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %513, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_180 = torch.constant.int 4 + %int1_180 = torch.constant.int 1 + %617 = torch.aten.add.Scalar %616, %int0_179, %int1_180 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> %int2_181 = torch.constant.int 2 - %int32_182 = torch.constant.int 32 + %618 = torch.aten.mul.Scalar %617, %int2_181 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_182 = torch.constant.int 1 + %619 = torch.aten.add.Tensor %618, %615, %int1_182 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int8_183 = torch.constant.int 8 - %int128_184 = torch.constant.int 128 - %514 = torch.prim.ListConstruct %int4_180, %358, %int2_181, %int32_182, %int8_183, %int128_184 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %515 = torch.aten.view %513, %514 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %515, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_185 = torch.constant.int 0 - %int0_186 = torch.constant.int 0 - %int9223372036854775807 = torch.constant.int 9223372036854775807 - %int1_187 = torch.constant.int 1 - %516 = torch.aten.slice.Tensor %515, %int0_185, %int0_186, %int9223372036854775807, %int1_187 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %516, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_188 = torch.constant.int 1 - %int0_189 = torch.constant.int 0 - %int9223372036854775807_190 = torch.constant.int 9223372036854775807 - %int1_191 = torch.constant.int 1 - %517 = torch.aten.slice.Tensor %516, %int1_188, %int0_189, %int9223372036854775807_190, %int1_191 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> 
!torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %517, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_192 = torch.constant.int 2 - %int0_193 = torch.constant.int 0 - %518 = torch.aten.select.int %517, %int2_192, %int0_193 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %518, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %620 = torch.aten.mul.Scalar %619, %int8_183 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_184 = torch.constant.int 1 + %621 = torch.aten.add.Tensor %620, %609, %int1_184 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_185 = torch.constant.int 32 + %622 = torch.aten.mul.Scalar %621, %int32_185 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_186 = torch.constant.int 1 + %623 = torch.aten.add.Tensor %622, %606, %int1_186 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_187 = torch.constant.int 5 + %624 = torch.prims.convert_element_type %590, %int5_187 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %625 = torch.prim.ListConstruct %623 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_188 = torch.constant.bool false + %626 = torch.aten.index_put %598, %625, %624, %false_188 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %626, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_189 = torch.constant.int 32 + %int2_190 = torch.constant.int 2 + %int8_191 = torch.constant.int 8 + %int32_192 = torch.constant.int 32 + %int128_193 = torch.constant.int 128 + %627 = torch.prim.ListConstruct %456, %int32_189, %int2_190, %int8_191, %int32_192, %int128_193 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %628 = torch.aten.view %626, %627 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %628, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152 = torch.constant.int 2097152 + %629 = torch.prim.ListConstruct %456, %int2097152 : (!torch.int, !torch.int) -> !torch.list + %630 = torch.aten.view %628, %629 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %630, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> %int32_194 = torch.constant.int 32 - %519 = torch.aten.mul.int %358, %int32_194 : !torch.int, !torch.int -> !torch.int %int2_195 = torch.constant.int 2 - %int0_196 = torch.constant.int 0 - %int1_197 = torch.constant.int 1 - %520 = torch.aten.slice.Tensor %518, %int2_195, %int0_196, %519, %int1_197 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %520, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_198 = torch.constant.int 0 - %521 = torch.aten.clone %520, %int0_198 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %521, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 
128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_199 = torch.constant.int 1 - %522 = torch.aten.size.int %517, %int1_199 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_200 = torch.constant.int 32 - %523 = torch.aten.mul.int %522, %int32_200 : !torch.int, !torch.int -> !torch.int - %int4_201 = torch.constant.int 4 - %int8_202 = torch.constant.int 8 - %int128_203 = torch.constant.int 128 - %524 = torch.prim.ListConstruct %int4_201, %523, %int8_202, %int128_203 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %525 = torch.aten._unsafe_view %521, %524 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %525, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_204 = torch.constant.int 0 + %int8_196 = torch.constant.int 8 + %int32_197 = torch.constant.int 32 + %int128_198 = torch.constant.int 128 + %631 = torch.prim.ListConstruct %456, %int32_194, %int2_195, %int8_196, %int32_197, %int128_198 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %632 = torch.aten.view %630, %631 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %632, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_199 = torch.constant.int 128 + %633 = torch.prim.ListConstruct %596, %int128_199 : (!torch.int, !torch.int) -> !torch.list + %634 = torch.aten.view %632, %633 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %634, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_200 = torch.constant.none + %635 = torch.aten.clone %7, %none_200 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %636 = torch.aten.detach %635 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %637 = torch.aten.detach %636 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %638 = torch.aten.detach %637 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_201 = torch.constant.int 1 + %int1_202 = torch.constant.int 1 + %int1_203 = torch.constant.int 1 + %639 = torch.prim.ListConstruct %int1_201, %int1_202, %int1_203 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %640 = torch.aten.view %638, %639 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_204 = torch.constant.int 32 + %641 = torch.aten.mul.Scalar %603, %int32_204 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int0_205 = torch.constant.int 0 - %int9223372036854775807_206 = torch.constant.int 9223372036854775807 - %int1_207 = torch.constant.int 1 - %526 = torch.aten.slice.Tensor %525, %int0_204, %int0_205, %int9223372036854775807_206, %int1_207 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %526, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_208 = torch.constant.int 0 - %int0_209 = torch.constant.int 0 - %int9223372036854775807_210 = torch.constant.int 9223372036854775807 - %int1_211 = torch.constant.int 1 - %527 = torch.aten.slice.Tensor %515, %int0_208, %int0_209, %int9223372036854775807_210, %int1_211 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape 
%527, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> + %int1_206 = torch.constant.int 1 + %642 = torch.aten.add.Scalar %641, %int0_205, %int1_206 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_207 = torch.constant.int 2 + %643 = torch.aten.mul.Scalar %642, %int2_207 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_208 = torch.constant.int 1 + %644 = torch.aten.add.Tensor %643, %640, %int1_208 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_209 = torch.constant.int 8 + %645 = torch.aten.mul.Scalar %644, %int8_209 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_210 = torch.constant.int 1 + %646 = torch.aten.add.Tensor %645, %609, %int1_210 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_211 = torch.constant.int 32 + %647 = torch.aten.mul.Scalar %646, %int32_211 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_212 = torch.constant.int 1 - %int0_213 = torch.constant.int 0 - %int9223372036854775807_214 = torch.constant.int 9223372036854775807 - %int1_215 = torch.constant.int 1 - %528 = torch.aten.slice.Tensor %527, %int1_212, %int0_213, %int9223372036854775807_214, %int1_215 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %528, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> + %648 = torch.aten.add.Tensor %647, %606, %int1_212 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_213 = torch.constant.int 5 + %649 = torch.prims.convert_element_type %570, %int5_213 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %650 = torch.prim.ListConstruct %648 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_214 = torch.constant.bool false + %651 = torch.aten.index_put %634, %650, %649, %false_214 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %651, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_215 = torch.constant.int 32 %int2_216 = torch.constant.int 2 - %int1_217 = torch.constant.int 1 - %529 = torch.aten.select.int %528, %int2_216, %int1_217 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %529, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_218 = torch.constant.int 2 - %int0_219 = torch.constant.int 0 - %int1_220 = torch.constant.int 1 - %530 = torch.aten.slice.Tensor %529, %int2_218, %int0_219, %519, %int1_220 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %530, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_221 = torch.constant.int 0 - %531 = torch.aten.clone %530, %int0_221 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %531, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_222 = torch.constant.int 1 - 
%532 = torch.aten.size.int %528, %int1_222 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_223 = torch.constant.int 32 - %533 = torch.aten.mul.int %532, %int32_223 : !torch.int, !torch.int -> !torch.int - %int4_224 = torch.constant.int 4 - %int8_225 = torch.constant.int 8 - %int128_226 = torch.constant.int 128 - %534 = torch.prim.ListConstruct %int4_224, %533, %int8_225, %int128_226 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %535 = torch.aten._unsafe_view %531, %534 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %535, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_227 = torch.constant.int 0 - %int0_228 = torch.constant.int 0 - %int9223372036854775807_229 = torch.constant.int 9223372036854775807 - %int1_230 = torch.constant.int 1 - %536 = torch.aten.slice.Tensor %535, %int0_227, %int0_228, %int9223372036854775807_229, %int1_230 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %536, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_231 = torch.constant.int -2 - %537 = torch.aten.unsqueeze %526, %int-2_231 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %537, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_232 = torch.constant.int 1 - %538 = torch.aten.size.int %525, %int1_232 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_233 = torch.constant.int 4 - %int8_234 = torch.constant.int 8 + %int8_217 = torch.constant.int 8 + %int32_218 = torch.constant.int 32 + %int128_219 = torch.constant.int 128 + %652 = torch.prim.ListConstruct %456, %int32_215, %int2_216, %int8_217, %int32_218, %int128_219 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %653 = torch.aten.view %651, %652 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %653, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_220 = torch.constant.int 2097152 + %654 = torch.prim.ListConstruct %456, %int2097152_220 : (!torch.int, !torch.int) -> !torch.list + %655 = torch.aten.view %653, %654 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %655, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_221 = torch.constant.none + %656 = torch.aten.clone %8, %none_221 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %657 = torch.aten.detach %656 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %658 = torch.aten.detach %657 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %659 = torch.aten.detach %658 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_222 = torch.constant.none + %660 = torch.aten.clone %9, %none_222 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %661 = torch.aten.detach %660 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %662 = torch.aten.detach %661 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %663 = torch.aten.detach %662 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_223 = torch.constant.none + %664 = torch.aten.clone %10, %none_223 : 
!torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %665 = torch.aten.detach %664 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %666 = torch.aten.detach %665 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %667 = torch.aten.detach %666 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_224 = torch.constant.int 32 + %int2_225 = torch.constant.int 2 + %int8_226 = torch.constant.int 8 + %int32_227 = torch.constant.int 32 + %int128_228 = torch.constant.int 128 + %668 = torch.prim.ListConstruct %456, %int32_224, %int2_225, %int8_226, %int32_227, %int128_228 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %669 = torch.aten.view %655, %668 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %669, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %670 = torch_c.to_builtin_tensor %669 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %671 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast = tensor.cast %671 : tensor<4x?xi64> to tensor + %672 = torch_c.to_builtin_tensor %659 : !torch.vtensor<[],si64> -> tensor + %673 = torch_c.to_builtin_tensor %663 : !torch.vtensor<[],si64> -> tensor + %674 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%670, %cast, %672, %673) : (tensor, tensor, tensor, tensor) -> tensor + %cast_229 = tensor.cast %674 : tensor to tensor<4x?x8x32x128xf16> + %675 = torch_c.from_builtin_tensor %cast_229 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %675, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %676 = torch_c.to_builtin_tensor %669 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %677 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_230 = tensor.cast %677 : tensor<4x?xi64> to tensor + %678 = torch_c.to_builtin_tensor %659 : !torch.vtensor<[],si64> -> tensor + %679 = torch_c.to_builtin_tensor %667 : !torch.vtensor<[],si64> -> tensor + %680 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%676, %cast_230, %678, %679) : (tensor, tensor, tensor, tensor) -> tensor + %cast_231 = tensor.cast %680 : tensor to tensor<4x?x8x32x128xf16> + %681 = torch_c.from_builtin_tensor %cast_231 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %681, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_232 = torch.constant.int 2 + %int3_233 = torch.constant.int 3 + %682 = torch.aten.transpose.int %675, %int2_232, %int3_233 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %682, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_234 = torch.constant.int 0 + %683 = torch.aten.clone %682, %int0_234 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %683, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : 
!torch.vtensor<[4,?,32,8,128],f16> %int4_235 = torch.constant.int 4 - %int128_236 = torch.constant.int 128 - %539 = torch.prim.ListConstruct %int4_233, %538, %int8_234, %int4_235, %int128_236 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_237 = torch.constant.bool false - %540 = torch.aten.expand %537, %539, %false_237 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %540, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_238 = torch.constant.int 0 - %541 = torch.aten.clone %540, %int0_238 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %541, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_239 = torch.constant.int 4 - %int32_240 = torch.constant.int 32 - %int128_241 = torch.constant.int 128 - %542 = torch.prim.ListConstruct %int4_239, %538, %int32_240, %int128_241 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %543 = torch.aten._unsafe_view %541, %542 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %543, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_242 = torch.constant.int -2 - %544 = torch.aten.unsqueeze %536, %int-2_242 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %544, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_243 = torch.constant.int 1 - %545 = torch.aten.size.int %535, %int1_243 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_244 = torch.constant.int 4 - %int8_245 = torch.constant.int 8 - %int4_246 = torch.constant.int 4 - %int128_247 = torch.constant.int 128 - %546 = torch.prim.ListConstruct %int4_244, %545, %int8_245, %int4_246, %int128_247 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_248 = torch.constant.bool false - %547 = torch.aten.expand %544, %546, %false_248 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %547, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_249 = torch.constant.int 0 - %548 = torch.aten.clone %547, %int0_249 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %548, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_250 = torch.constant.int 4 - %int32_251 = torch.constant.int 32 - %int128_252 = torch.constant.int 128 - %549 = torch.prim.ListConstruct %int4_250, %545, %int32_251, %int128_252 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %550 = torch.aten._unsafe_view %548, %549 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %550, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_253 = torch.constant.int 1 - %int2_254 = torch.constant.int 2 - %551 = torch.aten.transpose.int %430, %int1_253, %int2_254 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_255 = torch.constant.int 1 - %int2_256 = torch.constant.int 2 - %552 = torch.aten.transpose.int %543, 
%int1_255, %int2_256 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %552, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_257 = torch.constant.int 1 - %int2_258 = torch.constant.int 2 - %553 = torch.aten.transpose.int %550, %int1_257, %int2_258 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %553, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00 = torch.constant.float 0.000000e+00 + %int8_236 = torch.constant.int 8 + %int128_237 = torch.constant.int 128 + %684 = torch.prim.ListConstruct %int4_235, %457, %int8_236, %int128_237 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %685 = torch.aten._unsafe_view %683, %684 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %685, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_238 = torch.constant.int 2 + %int3_239 = torch.constant.int 3 + %686 = torch.aten.transpose.int %681, %int2_238, %int3_239 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %686, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_240 = torch.constant.int 0 + %687 = torch.aten.clone %686, %int0_240 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %687, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_241 = torch.constant.int 4 + %int8_242 = torch.constant.int 8 + %int128_243 = torch.constant.int 128 + %688 = torch.prim.ListConstruct %int4_241, %457, %int8_242, %int128_243 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %689 = torch.aten._unsafe_view %687, %688 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %689, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_244 = torch.constant.int -2 + %690 = torch.aten.unsqueeze %685, %int-2_244 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %690, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_245 = torch.constant.int 4 + %int8_246 = torch.constant.int 8 + %int4_247 = torch.constant.int 4 + %int128_248 = torch.constant.int 128 + %691 = torch.prim.ListConstruct %int4_245, %457, %int8_246, %int4_247, %int128_248 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_249 = torch.constant.bool false + %692 = torch.aten.expand %690, %691, %false_249 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %692, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_250 = torch.constant.int 0 + %693 = torch.aten.clone %692, %int0_250 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %693, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_251 = torch.constant.int 4 + %int32_252 = torch.constant.int 32 + %int128_253 = torch.constant.int 128 + %694 = 
torch.prim.ListConstruct %int4_251, %457, %int32_252, %int128_253 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %695 = torch.aten._unsafe_view %693, %694 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %695, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_254 = torch.constant.int -2 + %696 = torch.aten.unsqueeze %689, %int-2_254 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %696, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_255 = torch.constant.int 4 + %int8_256 = torch.constant.int 8 + %int4_257 = torch.constant.int 4 + %int128_258 = torch.constant.int 128 + %697 = torch.prim.ListConstruct %int4_255, %457, %int8_256, %int4_257, %int128_258 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %false_259 = torch.constant.bool false - %none_260 = torch.constant.none - %554:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%551, %552, %553, %float0.000000e00, %false_259, %368, %none_260) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_261 = torch.constant.int 1 - %int2_262 = torch.constant.int 2 - %555 = torch.aten.transpose.int %554#0, %int1_261, %int2_262 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_263 = torch.constant.int 4 + %698 = torch.aten.expand %696, %697, %false_259 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %698, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_260 = torch.constant.int 0 + %699 = torch.aten.clone %698, %int0_260 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %699, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_261 = torch.constant.int 4 + %int32_262 = torch.constant.int 32 + %int128_263 = torch.constant.int 128 + %700 = torch.prim.ListConstruct %int4_261, %457, %int32_262, %int128_263 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %701 = torch.aten._unsafe_view %699, %700 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %701, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int1_264 = torch.constant.int 1 - %int4096_265 = torch.constant.int 4096 - %556 = torch.prim.ListConstruct %int4_263, %int1_264, %int4096_265 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %557 = torch.aten.view %555, %556 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_266 = torch.constant.int -2 - %int-1_267 = torch.constant.int -1 - %558 = torch.aten.transpose.int %7, %int-2_266, %int-1_267 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_268 = torch.constant.int 4 - %int4096_269 = torch.constant.int 4096 - %559 = torch.prim.ListConstruct %int4_268, %int4096_269 : (!torch.int, !torch.int) -> !torch.list - %560 = torch.aten.view %557, %559 : 
!torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %561 = torch.aten.mm %560, %558 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_270 = torch.constant.int 4 - %int1_271 = torch.constant.int 1 - %int4096_272 = torch.constant.int 4096 - %562 = torch.prim.ListConstruct %int4_270, %int1_271, %int4096_272 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %563 = torch.aten.view %561, %562 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_273 = torch.constant.int 1 - %564 = torch.aten.add.Tensor %390, %563, %int1_273 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_274 = torch.constant.int 6 - %565 = torch.prims.convert_element_type %564, %int6_274 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_275 = torch.constant.int 2 - %566 = torch.aten.pow.Tensor_Scalar %565, %int2_275 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_276 = torch.constant.int -1 - %567 = torch.prim.ListConstruct %int-1_276 : (!torch.int) -> !torch.list - %true_277 = torch.constant.bool true - %none_278 = torch.constant.none - %568 = torch.aten.mean.dim %566, %567, %true_277, %none_278 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_279 = torch.constant.float 9.9999997473787516E-6 - %int1_280 = torch.constant.int 1 - %569 = torch.aten.add.Scalar %568, %float9.999990e-06_279, %int1_280 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %570 = torch.aten.rsqrt %569 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %571 = torch.aten.mul.Tensor %565, %570 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_281 = torch.constant.int 5 - %572 = torch.prims.convert_element_type %571, %int5_281 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %573 = torch.aten.mul.Tensor %8, %572 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_282 = torch.constant.int 5 - %574 = torch.prims.convert_element_type %573, %int5_282 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_283 = torch.constant.int -2 - %int-1_284 = torch.constant.int -1 - %575 = torch.aten.transpose.int %9, %int-2_283, %int-1_284 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_285 = torch.constant.int 4 - %int4096_286 = torch.constant.int 4096 - %576 = torch.prim.ListConstruct %int4_285, %int4096_286 : (!torch.int, !torch.int) -> !torch.list - %577 = torch.aten.view %574, %576 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %578 = torch.aten.mm %577, %575 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_287 = torch.constant.int 4 - %int1_288 = torch.constant.int 1 - %int14336 = torch.constant.int 14336 - %579 = torch.prim.ListConstruct %int4_287, %int1_288, %int14336 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %580 = torch.aten.view %578, %579 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %581 = torch.aten.silu %580 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_289 = torch.constant.int -2 - 
%int-1_290 = torch.constant.int -1 - %582 = torch.aten.transpose.int %10, %int-2_289, %int-1_290 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_291 = torch.constant.int 4 - %int4096_292 = torch.constant.int 4096 - %583 = torch.prim.ListConstruct %int4_291, %int4096_292 : (!torch.int, !torch.int) -> !torch.list - %584 = torch.aten.view %574, %583 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %585 = torch.aten.mm %584, %582 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_293 = torch.constant.int 4 - %int1_294 = torch.constant.int 1 - %int14336_295 = torch.constant.int 14336 - %586 = torch.prim.ListConstruct %int4_293, %int1_294, %int14336_295 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %587 = torch.aten.view %585, %586 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %588 = torch.aten.mul.Tensor %581, %587 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_296 = torch.constant.int -2 - %int-1_297 = torch.constant.int -1 - %589 = torch.aten.transpose.int %11, %int-2_296, %int-1_297 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int2_265 = torch.constant.int 2 + %702 = torch.aten.transpose.int %580, %int1_264, %int2_265 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_266 = torch.constant.int 1 + %int2_267 = torch.constant.int 2 + %703 = torch.aten.transpose.int %695, %int1_266, %int2_267 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %703, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_268 = torch.constant.int 1 + %int2_269 = torch.constant.int 2 + %704 = torch.aten.transpose.int %701, %int1_268, %int2_269 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %704, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00 = torch.constant.float 0.000000e+00 + %false_270 = torch.constant.bool false + %none_271 = torch.constant.none + %705:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%702, %703, %704, %float0.000000e00, %false_270, %470, %none_271) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_272 = torch.constant.int 1 + %int2_273 = torch.constant.int 2 + %706 = torch.aten.transpose.int %705#0, %int1_272, %int2_273 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_274 = torch.constant.int 4 + %int1_275 = torch.constant.int 1 + %int4096_276 = torch.constant.int 4096 + %707 = torch.prim.ListConstruct %int4_274, %int1_275, %int4096_276 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %708 = torch.aten.view %706, %707 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_277 = torch.constant.int -2 + %int-1_278 = torch.constant.int -1 + %709 = torch.aten.transpose.int %11, %int-2_277, %int-1_278 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> 
!torch.vtensor<[4096,4096],f16> + %int5_279 = torch.constant.int 5 + %710 = torch.prims.convert_element_type %709, %int5_279 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_280 = torch.constant.int 4 + %int4096_281 = torch.constant.int 4096 + %711 = torch.prim.ListConstruct %int4_280, %int4096_281 : (!torch.int, !torch.int) -> !torch.list + %712 = torch.aten.view %708, %711 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %713 = torch.aten.mm %712, %710 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_282 = torch.constant.int 4 + %int1_283 = torch.constant.int 1 + %int4096_284 = torch.constant.int 4096 + %714 = torch.prim.ListConstruct %int4_282, %int1_283, %int4096_284 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %715 = torch.aten.view %713, %714 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_285 = torch.constant.int 1 + %716 = torch.aten.add.Tensor %533, %715, %int1_285 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_286 = torch.constant.int 6 + %717 = torch.prims.convert_element_type %716, %int6_286 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_287 = torch.constant.int 2 + %718 = torch.aten.pow.Tensor_Scalar %717, %int2_287 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_288 = torch.constant.int -1 + %719 = torch.prim.ListConstruct %int-1_288 : (!torch.int) -> !torch.list + %true_289 = torch.constant.bool true + %none_290 = torch.constant.none + %720 = torch.aten.mean.dim %718, %719, %true_289, %none_290 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_291 = torch.constant.float 9.9999997473787516E-6 + %int1_292 = torch.constant.int 1 + %721 = torch.aten.add.Scalar %720, %float9.999990e-06_291, %int1_292 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %722 = torch.aten.rsqrt %721 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %723 = torch.aten.mul.Tensor %717, %722 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_293 = torch.constant.int 5 + %724 = torch.prims.convert_element_type %723, %int5_293 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %725 = torch.aten.mul.Tensor %12, %724 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_294 = torch.constant.int 5 + %726 = torch.prims.convert_element_type %725, %int5_294 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_295 = torch.constant.int -2 + %int-1_296 = torch.constant.int -1 + %727 = torch.aten.transpose.int %13, %int-2_295, %int-1_296 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_297 = torch.constant.int 5 + %728 = torch.prims.convert_element_type %727, %int5_297 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> %int4_298 = torch.constant.int 4 - %int14336_299 = torch.constant.int 14336 - %590 = torch.prim.ListConstruct %int4_298, %int14336_299 : (!torch.int, !torch.int) -> !torch.list - %591 = torch.aten.view %588, %590 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %592 = 
torch.aten.mm %591, %589 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4096_299 = torch.constant.int 4096 + %729 = torch.prim.ListConstruct %int4_298, %int4096_299 : (!torch.int, !torch.int) -> !torch.list + %730 = torch.aten.view %726, %729 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %731 = torch.aten.mm %730, %728 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> %int4_300 = torch.constant.int 4 %int1_301 = torch.constant.int 1 - %int4096_302 = torch.constant.int 4096 - %593 = torch.prim.ListConstruct %int4_300, %int1_301, %int4096_302 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %594 = torch.aten.view %592, %593 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_303 = torch.constant.int 1 - %595 = torch.aten.add.Tensor %564, %594, %int1_303 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_304 = torch.constant.int 6 - %596 = torch.prims.convert_element_type %595, %int6_304 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_305 = torch.constant.int 2 - %597 = torch.aten.pow.Tensor_Scalar %596, %int2_305 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_306 = torch.constant.int -1 - %598 = torch.prim.ListConstruct %int-1_306 : (!torch.int) -> !torch.list - %true_307 = torch.constant.bool true - %none_308 = torch.constant.none - %599 = torch.aten.mean.dim %597, %598, %true_307, %none_308 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_309 = torch.constant.float 9.9999997473787516E-6 - %int1_310 = torch.constant.int 1 - %600 = torch.aten.add.Scalar %599, %float9.999990e-06_309, %int1_310 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %601 = torch.aten.rsqrt %600 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %602 = torch.aten.mul.Tensor %596, %601 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_311 = torch.constant.int 5 - %603 = torch.prims.convert_element_type %602, %int5_311 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %604 = torch.aten.mul.Tensor %12, %603 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int14336 = torch.constant.int 14336 + %732 = torch.prim.ListConstruct %int4_300, %int1_301, %int14336 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %733 = torch.aten.view %731, %732 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %734 = torch.aten.silu %733 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_302 = torch.constant.int -2 + %int-1_303 = torch.constant.int -1 + %735 = torch.aten.transpose.int %14, %int-2_302, %int-1_303 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_304 = torch.constant.int 5 + %736 = torch.prims.convert_element_type %735, %int5_304 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_305 = torch.constant.int 4 + %int4096_306 = torch.constant.int 4096 + %737 = torch.prim.ListConstruct %int4_305, %int4096_306 : (!torch.int, !torch.int) -> !torch.list + %738 = torch.aten.view %726, %737 : 
!torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %739 = torch.aten.mm %738, %736 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_307 = torch.constant.int 4 + %int1_308 = torch.constant.int 1 + %int14336_309 = torch.constant.int 14336 + %740 = torch.prim.ListConstruct %int4_307, %int1_308, %int14336_309 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %741 = torch.aten.view %739, %740 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %742 = torch.aten.mul.Tensor %734, %741 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_310 = torch.constant.int -2 + %int-1_311 = torch.constant.int -1 + %743 = torch.aten.transpose.int %15, %int-2_310, %int-1_311 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> %int5_312 = torch.constant.int 5 - %605 = torch.prims.convert_element_type %604, %int5_312 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_313 = torch.constant.int -2 - %int-1_314 = torch.constant.int -1 - %606 = torch.aten.transpose.int %13, %int-2_313, %int-1_314 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %744 = torch.prims.convert_element_type %743, %int5_312 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_313 = torch.constant.int 4 + %int14336_314 = torch.constant.int 14336 + %745 = torch.prim.ListConstruct %int4_313, %int14336_314 : (!torch.int, !torch.int) -> !torch.list + %746 = torch.aten.view %742, %745 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %747 = torch.aten.mm %746, %744 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_315 = torch.constant.int 4 - %int4096_316 = torch.constant.int 4096 - %607 = torch.prim.ListConstruct %int4_315, %int4096_316 : (!torch.int, !torch.int) -> !torch.list - %608 = torch.aten.view %605, %607 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %609 = torch.aten.mm %608, %606 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_317 = torch.constant.int 4 + %int1_316 = torch.constant.int 1 + %int4096_317 = torch.constant.int 4096 + %748 = torch.prim.ListConstruct %int4_315, %int1_316, %int4096_317 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %749 = torch.aten.view %747, %748 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_318 = torch.constant.int 1 - %int4096_319 = torch.constant.int 4096 - %610 = torch.prim.ListConstruct %int4_317, %int1_318, %int4096_319 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %611 = torch.aten.view %609, %610 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_320 = torch.constant.int -2 + %750 = torch.aten.add.Tensor %716, %749, %int1_318 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_319 = torch.constant.int 6 + %751 = torch.prims.convert_element_type %750, %int6_319 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_320 = torch.constant.int 2 + %752 = torch.aten.pow.Tensor_Scalar %751, %int2_320 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> %int-1_321 = torch.constant.int 
-1 - %612 = torch.aten.transpose.int %14, %int-2_320, %int-1_321 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_322 = torch.constant.int 4 - %int4096_323 = torch.constant.int 4096 - %613 = torch.prim.ListConstruct %int4_322, %int4096_323 : (!torch.int, !torch.int) -> !torch.list - %614 = torch.aten.view %605, %613 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %615 = torch.aten.mm %614, %612 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_324 = torch.constant.int 4 + %753 = torch.prim.ListConstruct %int-1_321 : (!torch.int) -> !torch.list + %true_322 = torch.constant.bool true + %none_323 = torch.constant.none + %754 = torch.aten.mean.dim %752, %753, %true_322, %none_323 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_324 = torch.constant.float 9.9999997473787516E-6 %int1_325 = torch.constant.int 1 - %int1024_326 = torch.constant.int 1024 - %616 = torch.prim.ListConstruct %int4_324, %int1_325, %int1024_326 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %617 = torch.aten.view %615, %616 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_327 = torch.constant.int -2 - %int-1_328 = torch.constant.int -1 - %618 = torch.aten.transpose.int %15, %int-2_327, %int-1_328 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_329 = torch.constant.int 4 - %int4096_330 = torch.constant.int 4096 - %619 = torch.prim.ListConstruct %int4_329, %int4096_330 : (!torch.int, !torch.int) -> !torch.list - %620 = torch.aten.view %605, %619 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %621 = torch.aten.mm %620, %618 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %755 = torch.aten.add.Scalar %754, %float9.999990e-06_324, %int1_325 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %756 = torch.aten.rsqrt %755 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %757 = torch.aten.mul.Tensor %751, %756 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_326 = torch.constant.int 5 + %758 = torch.prims.convert_element_type %757, %int5_326 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %759 = torch.aten.mul.Tensor %16, %758 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_327 = torch.constant.int 5 + %760 = torch.prims.convert_element_type %759, %int5_327 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_328 = torch.constant.int -2 + %int-1_329 = torch.constant.int -1 + %761 = torch.aten.transpose.int %17, %int-2_328, %int-1_329 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_330 = torch.constant.int 5 + %762 = torch.prims.convert_element_type %761, %int5_330 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> %int4_331 = torch.constant.int 4 - %int1_332 = torch.constant.int 1 - %int1024_333 = torch.constant.int 1024 - %622 = torch.prim.ListConstruct %int4_331, %int1_332, %int1024_333 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %623 = torch.aten.view %621, %622 : !torch.vtensor<[4,1024],f16>, !torch.list -> 
!torch.vtensor<[4,1,1024],f16> - %int4_334 = torch.constant.int 4 - %int1_335 = torch.constant.int 1 - %int32_336 = torch.constant.int 32 - %int128_337 = torch.constant.int 128 - %624 = torch.prim.ListConstruct %int4_334, %int1_335, %int32_336, %int128_337 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %625 = torch.aten.view %611, %624 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_338 = torch.constant.int 4 - %int1_339 = torch.constant.int 1 - %int8_340 = torch.constant.int 8 - %int128_341 = torch.constant.int 128 - %626 = torch.prim.ListConstruct %int4_338, %int1_339, %int8_340, %int128_341 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %627 = torch.aten.view %617, %626 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_342 = torch.constant.int 4 - %int1_343 = torch.constant.int 1 - %int8_344 = torch.constant.int 8 - %int128_345 = torch.constant.int 128 - %628 = torch.prim.ListConstruct %int4_342, %int1_343, %int8_344, %int128_345 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %629 = torch.aten.view %623, %628 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_346 = torch.constant.int 6 - %630 = torch.prims.convert_element_type %625, %int6_346 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %631 = torch_c.to_builtin_tensor %630 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %632 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %633 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%631, %632) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %634 = torch_c.from_builtin_tensor %633 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_347 = torch.constant.int 5 - %635 = torch.prims.convert_element_type %634, %int5_347 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_348 = torch.constant.int 6 - %636 = torch.prims.convert_element_type %627, %int6_348 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %637 = torch_c.to_builtin_tensor %636 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %638 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %639 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%637, %638) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %640 = torch_c.from_builtin_tensor %639 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_349 = torch.constant.int 5 - %641 = torch.prims.convert_element_type %640, %int5_349 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_350 = torch.constant.int 32 - %642 = torch.aten.floor_divide.Scalar %arg2, %int32_350 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_351 = torch.constant.int 1 - %643 = torch.aten.unsqueeze %642, %int1_351 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_352 = torch.constant.int 1 - %false_353 = torch.constant.bool false - %644 = torch.aten.gather %arg3, %int1_352, %643, %false_353 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4096_332 = torch.constant.int 4096 + %763 = torch.prim.ListConstruct %int4_331, %int4096_332 : (!torch.int, !torch.int) -> !torch.list + %764 = 
torch.aten.view %760, %763 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %765 = torch.aten.mm %764, %762 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_333 = torch.constant.int 4 + %int1_334 = torch.constant.int 1 + %int4096_335 = torch.constant.int 4096 + %766 = torch.prim.ListConstruct %int4_333, %int1_334, %int4096_335 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %767 = torch.aten.view %765, %766 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_336 = torch.constant.int -2 + %int-1_337 = torch.constant.int -1 + %768 = torch.aten.transpose.int %18, %int-2_336, %int-1_337 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_338 = torch.constant.int 5 + %769 = torch.prims.convert_element_type %768, %int5_338 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_339 = torch.constant.int 4 + %int4096_340 = torch.constant.int 4096 + %770 = torch.prim.ListConstruct %int4_339, %int4096_340 : (!torch.int, !torch.int) -> !torch.list + %771 = torch.aten.view %760, %770 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %772 = torch.aten.mm %771, %769 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_341 = torch.constant.int 4 + %int1_342 = torch.constant.int 1 + %int1024_343 = torch.constant.int 1024 + %773 = torch.prim.ListConstruct %int4_341, %int1_342, %int1024_343 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %774 = torch.aten.view %772, %773 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_344 = torch.constant.int -2 + %int-1_345 = torch.constant.int -1 + %775 = torch.aten.transpose.int %19, %int-2_344, %int-1_345 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_346 = torch.constant.int 5 + %776 = torch.prims.convert_element_type %775, %int5_346 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_347 = torch.constant.int 4 + %int4096_348 = torch.constant.int 4096 + %777 = torch.prim.ListConstruct %int4_347, %int4096_348 : (!torch.int, !torch.int) -> !torch.list + %778 = torch.aten.view %760, %777 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %779 = torch.aten.mm %778, %776 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_349 = torch.constant.int 4 + %int1_350 = torch.constant.int 1 + %int1024_351 = torch.constant.int 1024 + %780 = torch.prim.ListConstruct %int4_349, %int1_350, %int1024_351 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %781 = torch.aten.view %779, %780 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_352 = torch.constant.int 4 + %int1_353 = torch.constant.int 1 %int32_354 = torch.constant.int 32 - %645 = torch.aten.remainder.Scalar %arg2, %int32_354 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_355 = torch.constant.int 1 - %646 = torch.aten.unsqueeze %645, %int1_355 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_356 = torch.constant.none - %647 = torch.aten.clone %16, %none_356 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_357 = torch.constant.int 0 - %648 = torch.aten.unsqueeze %647, %int0_357 : !torch.vtensor<[],si64>, 
!torch.int -> !torch.vtensor<[1],si64> - %int4_358 = torch.constant.int 4 - %int1_359 = torch.constant.int 1 - %649 = torch.prim.ListConstruct %int4_358, %int1_359 : (!torch.int, !torch.int) -> !torch.list - %int1_360 = torch.constant.int 1 + %int128_355 = torch.constant.int 128 + %782 = torch.prim.ListConstruct %int4_352, %int1_353, %int32_354, %int128_355 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %783 = torch.aten.view %767, %782 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_356 = torch.constant.int 4 + %int1_357 = torch.constant.int 1 + %int8_358 = torch.constant.int 8 + %int128_359 = torch.constant.int 128 + %784 = torch.prim.ListConstruct %int4_356, %int1_357, %int8_358, %int128_359 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %785 = torch.aten.view %774, %784 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_360 = torch.constant.int 4 %int1_361 = torch.constant.int 1 - %650 = torch.prim.ListConstruct %int1_360, %int1_361 : (!torch.int, !torch.int) -> !torch.list - %int4_362 = torch.constant.int 4 - %int0_363 = torch.constant.int 0 - %cpu_364 = torch.constant.device "cpu" - %false_365 = torch.constant.bool false - %651 = torch.aten.empty_strided %649, %650, %int4_362, %int0_363, %cpu_364, %false_365 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int1_366 = torch.constant.int 1 - %652 = torch.aten.fill.Scalar %651, %int1_366 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_367 = torch.constant.int 4 - %int1_368 = torch.constant.int 1 - %653 = torch.prim.ListConstruct %int4_367, %int1_368 : (!torch.int, !torch.int) -> !torch.list - %654 = torch.aten.repeat %648, %653 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_369 = torch.constant.int 32 - %655 = torch.aten.mul.Scalar %644, %int32_369 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_370 = torch.constant.int 1 - %656 = torch.aten.add.Tensor %655, %652, %int1_370 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_371 = torch.constant.int 2 - %657 = torch.aten.mul.Scalar %656, %int2_371 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_372 = torch.constant.int 1 - %658 = torch.aten.add.Tensor %657, %654, %int1_372 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_373 = torch.constant.int 32 - %659 = torch.aten.mul.Scalar %658, %int32_373 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_374 = torch.constant.int 1 - %660 = torch.aten.add.Tensor %659, %646, %int1_374 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_375 = torch.constant.int 32 - %int2_376 = torch.constant.int 2 - %int32_377 = torch.constant.int 32 - %int8_378 = torch.constant.int 8 - %int128_379 = torch.constant.int 128 - %661 = torch.prim.ListConstruct %437, %int32_375, %int2_376, %int32_377, %int8_378, %int128_379 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %662 = torch.aten.view %498, %661 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %662, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - 
%int32_380 = torch.constant.int 32 - %663 = torch.aten.mul.int %437, %int32_380 : !torch.int, !torch.int -> !torch.int - %int2_381 = torch.constant.int 2 - %664 = torch.aten.mul.int %663, %int2_381 : !torch.int, !torch.int -> !torch.int - %int32_382 = torch.constant.int 32 - %665 = torch.aten.mul.int %664, %int32_382 : !torch.int, !torch.int -> !torch.int - %int8_383 = torch.constant.int 8 - %int128_384 = torch.constant.int 128 - %666 = torch.prim.ListConstruct %665, %int8_383, %int128_384 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %667 = torch.aten.view %662, %666 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %667, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %668 = torch.prim.ListConstruct %660 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_385 = torch.constant.bool false - %669 = torch.aten.index_put %667, %668, %641, %false_385 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %669, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_386 = torch.constant.int 32 - %int2_387 = torch.constant.int 2 - %int32_388 = torch.constant.int 32 - %int8_389 = torch.constant.int 8 - %int128_390 = torch.constant.int 128 - %670 = torch.prim.ListConstruct %437, %int32_386, %int2_387, %int32_388, %int8_389, %int128_390 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %671 = torch.aten.view %669, %670 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %671, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_391 = torch.constant.int 2097152 - %672 = torch.prim.ListConstruct %437, %int2097152_391 : (!torch.int, !torch.int) -> !torch.list - %673 = torch.aten.view %671, %672 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %673, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int8_362 = torch.constant.int 8 + %int128_363 = torch.constant.int 128 + %786 = torch.prim.ListConstruct %int4_360, %int1_361, %int8_362, %int128_363 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %787 = torch.aten.view %781, %786 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_364 = torch.constant.int 1 + %int2_365 = torch.constant.int 2 + %788 = torch.aten.transpose.int %783, %int1_364, %int2_365 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %789 = torch.aten.mul.Tensor %788, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_366 = torch.constant.int 3 + %int0_367 = torch.constant.int 0 + %int64_368 = torch.constant.int 64 + %int1_369 = torch.constant.int 1 + %790 = torch.aten.slice.Tensor %788, %int3_366, %int0_367, %int64_368, %int1_369 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_370 = torch.constant.int 3 + %int64_371 = torch.constant.int 64 + %int9223372036854775807_372 = torch.constant.int 9223372036854775807 + %int1_373 = torch.constant.int 1 + %791 = torch.aten.slice.Tensor %788, %int3_370, %int64_371, %int9223372036854775807_372, %int1_373 : 
!torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %792 = torch.aten.neg %791 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %793 = torch.prim.ListConstruct %792, %790 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_374 = torch.constant.int -1 + %794 = torch.aten.cat %793, %int-1_374 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %795 = torch.aten.mul.Tensor %794, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_375 = torch.constant.int 1 + %796 = torch.aten.add.Tensor %789, %795, %int1_375 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_376 = torch.constant.int 1 + %int2_377 = torch.constant.int 2 + %797 = torch.aten.transpose.int %796, %int1_376, %int2_377 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_378 = torch.constant.int 1 + %int2_379 = torch.constant.int 2 + %798 = torch.aten.transpose.int %785, %int1_378, %int2_379 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %799 = torch.aten.mul.Tensor %798, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_380 = torch.constant.int 3 + %int0_381 = torch.constant.int 0 + %int64_382 = torch.constant.int 64 + %int1_383 = torch.constant.int 1 + %800 = torch.aten.slice.Tensor %798, %int3_380, %int0_381, %int64_382, %int1_383 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_384 = torch.constant.int 3 + %int64_385 = torch.constant.int 64 + %int9223372036854775807_386 = torch.constant.int 9223372036854775807 + %int1_387 = torch.constant.int 1 + %801 = torch.aten.slice.Tensor %798, %int3_384, %int64_385, %int9223372036854775807_386, %int1_387 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %802 = torch.aten.neg %801 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %803 = torch.prim.ListConstruct %802, %800 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_388 = torch.constant.int -1 + %804 = torch.aten.cat %803, %int-1_388 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %805 = torch.aten.mul.Tensor %804, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_389 = torch.constant.int 1 + %806 = torch.aten.add.Tensor %799, %805, %int1_389 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_390 = torch.constant.int 1 + %int2_391 = torch.constant.int 2 + %807 = torch.aten.transpose.int %806, %int1_390, %int2_391 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> %int32_392 = torch.constant.int 32 - %int2_393 = torch.constant.int 2 - %int32_394 = torch.constant.int 32 - %int8_395 = torch.constant.int 8 - %int128_396 = torch.constant.int 128 - %674 = torch.prim.ListConstruct %437, %int32_392, %int2_393, %int32_394, %int8_395, %int128_396 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %675 = torch.aten.view %673, %674 : !torch.vtensor<[?,2097152],f16>, !torch.list -> 
!torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %675, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_397 = torch.constant.int 8 - %int128_398 = torch.constant.int 128 - %676 = torch.prim.ListConstruct %665, %int8_397, %int128_398 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %677 = torch.aten.view %675, %676 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %677, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> + %808 = torch.aten.floor_divide.Scalar %arg2, %int32_392 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_393 = torch.constant.int 1 + %809 = torch.aten.unsqueeze %808, %int1_393 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_394 = torch.constant.int 1 + %false_395 = torch.constant.bool false + %810 = torch.aten.gather %arg3, %int1_394, %809, %false_395 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_396 = torch.constant.int 4 + %int1_397 = torch.constant.int 1 + %int1_398 = torch.constant.int 1 + %811 = torch.prim.ListConstruct %int4_396, %int1_397, %int1_398 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %812 = torch.aten.view %810, %811 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> %int32_399 = torch.constant.int 32 - %678 = torch.aten.floor_divide.Scalar %arg2, %int32_399 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_400 = torch.constant.int 1 - %679 = torch.aten.unsqueeze %678, %int1_400 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %813 = torch.aten.remainder.Scalar %arg2, %int32_399 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_400 = torch.constant.int 4 %int1_401 = torch.constant.int 1 - %false_402 = torch.constant.bool false - %680 = torch.aten.gather %arg3, %int1_401, %679, %false_402 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_403 = torch.constant.int 32 - %681 = torch.aten.remainder.Scalar %arg2, %int32_403 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_404 = torch.constant.int 1 - %682 = torch.aten.unsqueeze %681, %int1_404 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_402 = torch.constant.int 1 + %814 = torch.prim.ListConstruct %int4_400, %int1_401, %int1_402 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %815 = torch.aten.view %813, %814 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_403 = torch.constant.int 8 + %none_404 = torch.constant.none %none_405 = torch.constant.none - %683 = torch.aten.clone %17, %none_405 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_406 = torch.constant.int 0 - %684 = torch.aten.unsqueeze %683, %int0_406 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_407 = torch.constant.int 4 + %cpu_406 = torch.constant.device "cpu" + %false_407 = torch.constant.bool false + %816 = torch.aten.arange %int8_403, %none_404, %none_405, %cpu_406, %false_407 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> %int1_408 = torch.constant.int 1 - %685 = torch.prim.ListConstruct %int4_407, %int1_408 : (!torch.int, !torch.int) -> !torch.list %int1_409 = torch.constant.int 1 - 
%int1_410 = torch.constant.int 1 - %686 = torch.prim.ListConstruct %int1_409, %int1_410 : (!torch.int, !torch.int) -> !torch.list - %int4_411 = torch.constant.int 4 - %int0_412 = torch.constant.int 0 - %cpu_413 = torch.constant.device "cpu" - %false_414 = torch.constant.bool false - %687 = torch.aten.empty_strided %685, %686, %int4_411, %int0_412, %cpu_413, %false_414 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int1_415 = torch.constant.int 1 - %688 = torch.aten.fill.Scalar %687, %int1_415 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_416 = torch.constant.int 4 + %int8_410 = torch.constant.int 8 + %817 = torch.prim.ListConstruct %int1_408, %int1_409, %int8_410 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %818 = torch.aten.view %816, %817 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_411 = torch.constant.none + %819 = torch.aten.clone %20, %none_411 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %820 = torch.aten.detach %819 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %821 = torch.aten.detach %820 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %822 = torch.aten.detach %821 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_412 = torch.constant.int 1 + %int1_413 = torch.constant.int 1 + %int1_414 = torch.constant.int 1 + %823 = torch.prim.ListConstruct %int1_412, %int1_413, %int1_414 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %824 = torch.aten.view %822, %823 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_415 = torch.constant.int 32 + %825 = torch.aten.mul.Scalar %812, %int32_415 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_416 = torch.constant.int 1 %int1_417 = torch.constant.int 1 - %689 = torch.prim.ListConstruct %int4_416, %int1_417 : (!torch.int, !torch.int) -> !torch.list - %690 = torch.aten.repeat %684, %689 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_418 = torch.constant.int 32 - %691 = torch.aten.mul.Scalar %680, %int32_418 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %826 = torch.aten.add.Scalar %825, %int1_416, %int1_417 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_418 = torch.constant.int 2 + %827 = torch.aten.mul.Scalar %826, %int2_418 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_419 = torch.constant.int 1 - %692 = torch.aten.add.Tensor %691, %688, %int1_419 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_420 = torch.constant.int 2 - %693 = torch.aten.mul.Scalar %692, %int2_420 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %828 = torch.aten.add.Tensor %827, %824, %int1_419 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_420 = torch.constant.int 8 + %829 = torch.aten.mul.Scalar %828, %int8_420 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_421 = torch.constant.int 1 - %694 = torch.aten.add.Tensor %693, %690, %int1_421 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %830 = torch.aten.add.Tensor %829, %818, %int1_421 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> 
%int32_422 = torch.constant.int 32 - %695 = torch.aten.mul.Scalar %694, %int32_422 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %831 = torch.aten.mul.Scalar %830, %int32_422 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_423 = torch.constant.int 1 - %696 = torch.aten.add.Tensor %695, %682, %int1_423 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %697 = torch.prim.ListConstruct %696 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_424 = torch.constant.bool false - %698 = torch.aten.index_put %677, %697, %629, %false_424 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %698, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> + %832 = torch.aten.add.Tensor %831, %815, %int1_423 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_424 = torch.constant.int 5 + %833 = torch.prims.convert_element_type %807, %int5_424 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> %int32_425 = torch.constant.int 32 %int2_426 = torch.constant.int 2 - %int32_427 = torch.constant.int 32 - %int8_428 = torch.constant.int 8 + %int8_427 = torch.constant.int 8 + %int32_428 = torch.constant.int 32 %int128_429 = torch.constant.int 128 - %699 = torch.prim.ListConstruct %437, %int32_425, %int2_426, %int32_427, %int8_428, %int128_429 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %700 = torch.aten.view %698, %699 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %700, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_430 = torch.constant.int 2097152 - %701 = torch.prim.ListConstruct %437, %int2097152_430 : (!torch.int, !torch.int) -> !torch.list - %702 = torch.aten.view %700, %701 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %702, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_431 = torch.constant.int 4 - %703 = torch.prim.ListConstruct %int4_431, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_432 = torch.constant.int 1 - %704 = torch.prim.ListConstruct %358, %int1_432 : (!torch.int, !torch.int) -> !torch.list - %int4_433 = torch.constant.int 4 - %int0_434 = torch.constant.int 0 - %cpu_435 = torch.constant.device "cpu" - %false_436 = torch.constant.bool false - %705 = torch.aten.empty_strided %703, %704, %int4_433, %int0_434, %cpu_435, %false_436 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %705, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_437 = torch.constant.int 1 - %706 = torch.aten.fill.Scalar %705, %int1_437 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %706, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %834 = torch.prim.ListConstruct %456, %int32_425, %int2_426, %int8_427, %int32_428, %int128_429 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %835 = torch.aten.view %655, %834 : !torch.vtensor<[?,2097152],f16>, !torch.list -> 
!torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %835, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_430 = torch.constant.int 128 + %836 = torch.prim.ListConstruct %596, %int128_430 : (!torch.int, !torch.int) -> !torch.list + %837 = torch.aten.view %835, %836 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %837, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %838 = torch.prim.ListConstruct %832 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_431 = torch.constant.bool false + %839 = torch.aten.index_put %837, %838, %833, %false_431 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %839, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_432 = torch.constant.int 32 + %int2_433 = torch.constant.int 2 + %int8_434 = torch.constant.int 8 + %int32_435 = torch.constant.int 32 + %int128_436 = torch.constant.int 128 + %840 = torch.prim.ListConstruct %456, %int32_432, %int2_433, %int8_434, %int32_435, %int128_436 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %841 = torch.aten.view %839, %840 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %841, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_437 = torch.constant.int 2097152 + %842 = torch.prim.ListConstruct %456, %int2097152_437 : (!torch.int, !torch.int) -> !torch.list + %843 = torch.aten.view %841, %842 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %843, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> %int32_438 = torch.constant.int 32 - %707 = torch.aten.mul.Scalar %arg3, %int32_438 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %707, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_439 = torch.constant.int 1 - %708 = torch.aten.add.Tensor %707, %706, %int1_439 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %708, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_440 = torch.constant.int 4 - %709 = torch.aten.mul.int %int4_440, %358 : !torch.int, !torch.int -> !torch.int - %710 = torch.prim.ListConstruct %709 : (!torch.int) -> !torch.list - %711 = torch.aten.view %708, %710 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %711, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> + %int2_439 = torch.constant.int 2 + %int8_440 = torch.constant.int 8 %int32_441 = torch.constant.int 32 - %int2_442 = torch.constant.int 2 - %int32_443 = torch.constant.int 32 - %int8_444 = torch.constant.int 8 - %int128_445 = torch.constant.int 128 - %712 = torch.prim.ListConstruct %437, %int32_441, %int2_442, %int32_443, %int8_444, %int128_445 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %713 = torch.aten.view %702, %712 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %713, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : 
!torch.vtensor<[?,32,2,32,8,128],f16> - %int32_446 = torch.constant.int 32 - %714 = torch.aten.mul.int %437, %int32_446 : !torch.int, !torch.int -> !torch.int - %int2_447 = torch.constant.int 2 + %int128_442 = torch.constant.int 128 + %844 = torch.prim.ListConstruct %456, %int32_438, %int2_439, %int8_440, %int32_441, %int128_442 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %845 = torch.aten.view %843, %844 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %845, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_443 = torch.constant.int 128 + %846 = torch.prim.ListConstruct %596, %int128_443 : (!torch.int, !torch.int) -> !torch.list + %847 = torch.aten.view %845, %846 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %847, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_444 = torch.constant.none + %848 = torch.aten.clone %21, %none_444 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %849 = torch.aten.detach %848 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %850 = torch.aten.detach %849 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %851 = torch.aten.detach %850 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_445 = torch.constant.int 1 + %int1_446 = torch.constant.int 1 + %int1_447 = torch.constant.int 1 + %852 = torch.prim.ListConstruct %int1_445, %int1_446, %int1_447 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %853 = torch.aten.view %851, %852 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> %int32_448 = torch.constant.int 32 - %int8_449 = torch.constant.int 8 - %int128_450 = torch.constant.int 128 - %715 = torch.prim.ListConstruct %714, %int2_447, %int32_448, %int8_449, %int128_450 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %716 = torch.aten.view %713, %715 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %716, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_451 = torch.constant.int 0 - %717 = torch.aten.index_select %716, %int0_451, %711 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %717, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_452 = torch.constant.int 4 - %int2_453 = torch.constant.int 2 - %int32_454 = torch.constant.int 32 - %int8_455 = torch.constant.int 8 - %int128_456 = torch.constant.int 128 - %718 = torch.prim.ListConstruct %int4_452, %358, %int2_453, %int32_454, %int8_455, %int128_456 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %719 = torch.aten.view %717, %718 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %719, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_457 = torch.constant.int 0 - %int0_458 = torch.constant.int 0 - %int9223372036854775807_459 = torch.constant.int 9223372036854775807 - %int1_460 = torch.constant.int 1 - %720 = torch.aten.slice.Tensor %719, %int0_457, %int0_458, %int9223372036854775807_459, %int1_460 : 
!torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %720, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_461 = torch.constant.int 1 - %int0_462 = torch.constant.int 0 - %int9223372036854775807_463 = torch.constant.int 9223372036854775807 - %int1_464 = torch.constant.int 1 - %721 = torch.aten.slice.Tensor %720, %int1_461, %int0_462, %int9223372036854775807_463, %int1_464 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %721, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_465 = torch.constant.int 2 - %int0_466 = torch.constant.int 0 - %722 = torch.aten.select.int %721, %int2_465, %int0_466 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %722, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_467 = torch.constant.int 32 - %723 = torch.aten.mul.int %358, %int32_467 : !torch.int, !torch.int -> !torch.int - %int2_468 = torch.constant.int 2 - %int0_469 = torch.constant.int 0 - %int1_470 = torch.constant.int 1 - %724 = torch.aten.slice.Tensor %722, %int2_468, %int0_469, %723, %int1_470 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %724, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_471 = torch.constant.int 0 - %725 = torch.aten.clone %724, %int0_471 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %725, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_472 = torch.constant.int 1 - %726 = torch.aten.size.int %721, %int1_472 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_473 = torch.constant.int 32 - %727 = torch.aten.mul.int %726, %int32_473 : !torch.int, !torch.int -> !torch.int - %int4_474 = torch.constant.int 4 - %int8_475 = torch.constant.int 8 - %int128_476 = torch.constant.int 128 - %728 = torch.prim.ListConstruct %int4_474, %727, %int8_475, %int128_476 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %729 = torch.aten._unsafe_view %725, %728 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %729, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_477 = torch.constant.int 0 - %int0_478 = torch.constant.int 0 - %int9223372036854775807_479 = torch.constant.int 9223372036854775807 - %int1_480 = torch.constant.int 1 - %730 = torch.aten.slice.Tensor %729, %int0_477, %int0_478, %int9223372036854775807_479, %int1_480 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %730, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_481 = torch.constant.int 0 - %int0_482 = torch.constant.int 0 - %int9223372036854775807_483 = torch.constant.int 9223372036854775807 - %int1_484 = torch.constant.int 1 - %731 = torch.aten.slice.Tensor %719, %int0_481, %int0_482, %int9223372036854775807_483, %int1_484 : !torch.vtensor<[4,?,2,32,8,128],f16>, 
!torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %731, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_485 = torch.constant.int 1 - %int0_486 = torch.constant.int 0 - %int9223372036854775807_487 = torch.constant.int 9223372036854775807 - %int1_488 = torch.constant.int 1 - %732 = torch.aten.slice.Tensor %731, %int1_485, %int0_486, %int9223372036854775807_487, %int1_488 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %732, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_489 = torch.constant.int 2 - %int1_490 = torch.constant.int 1 - %733 = torch.aten.select.int %732, %int2_489, %int1_490 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %733, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_491 = torch.constant.int 2 - %int0_492 = torch.constant.int 0 - %int1_493 = torch.constant.int 1 - %734 = torch.aten.slice.Tensor %733, %int2_491, %int0_492, %723, %int1_493 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %734, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_494 = torch.constant.int 0 - %735 = torch.aten.clone %734, %int0_494 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %735, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_495 = torch.constant.int 1 - %736 = torch.aten.size.int %732, %int1_495 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_496 = torch.constant.int 32 - %737 = torch.aten.mul.int %736, %int32_496 : !torch.int, !torch.int -> !torch.int - %int4_497 = torch.constant.int 4 - %int8_498 = torch.constant.int 8 - %int128_499 = torch.constant.int 128 - %738 = torch.prim.ListConstruct %int4_497, %737, %int8_498, %int128_499 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %739 = torch.aten._unsafe_view %735, %738 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %739, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_500 = torch.constant.int 0 - %int0_501 = torch.constant.int 0 - %int9223372036854775807_502 = torch.constant.int 9223372036854775807 - %int1_503 = torch.constant.int 1 - %740 = torch.aten.slice.Tensor %739, %int0_500, %int0_501, %int9223372036854775807_502, %int1_503 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %740, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_504 = torch.constant.int -2 - %741 = torch.aten.unsqueeze %730, %int-2_504 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %741, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_505 = torch.constant.int 1 - %742 = torch.aten.size.int %729, %int1_505 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int + %854 = torch.aten.mul.Scalar %812, 
%int32_448 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_449 = torch.constant.int 1 + %int1_450 = torch.constant.int 1 + %855 = torch.aten.add.Scalar %854, %int1_449, %int1_450 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_451 = torch.constant.int 2 + %856 = torch.aten.mul.Scalar %855, %int2_451 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_452 = torch.constant.int 1 + %857 = torch.aten.add.Tensor %856, %853, %int1_452 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_453 = torch.constant.int 8 + %858 = torch.aten.mul.Scalar %857, %int8_453 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_454 = torch.constant.int 1 + %859 = torch.aten.add.Tensor %858, %818, %int1_454 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_455 = torch.constant.int 32 + %860 = torch.aten.mul.Scalar %859, %int32_455 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_456 = torch.constant.int 1 + %861 = torch.aten.add.Tensor %860, %815, %int1_456 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_457 = torch.constant.int 5 + %862 = torch.prims.convert_element_type %787, %int5_457 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %863 = torch.prim.ListConstruct %861 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_458 = torch.constant.bool false + %864 = torch.aten.index_put %847, %863, %862, %false_458 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %864, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_459 = torch.constant.int 32 + %int2_460 = torch.constant.int 2 + %int8_461 = torch.constant.int 8 + %int32_462 = torch.constant.int 32 + %int128_463 = torch.constant.int 128 + %865 = torch.prim.ListConstruct %456, %int32_459, %int2_460, %int8_461, %int32_462, %int128_463 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %866 = torch.aten.view %864, %865 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %866, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_464 = torch.constant.int 2097152 + %867 = torch.prim.ListConstruct %456, %int2097152_464 : (!torch.int, !torch.int) -> !torch.list + %868 = torch.aten.view %866, %867 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %868, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_465 = torch.constant.none + %869 = torch.aten.clone %22, %none_465 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %870 = torch.aten.detach %869 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %871 = torch.aten.detach %870 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %872 = torch.aten.detach %871 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_466 = torch.constant.none + %873 = torch.aten.clone %23, %none_466 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %874 = torch.aten.detach %873 : !torch.vtensor<[],si64> -> 
!torch.vtensor<[],si64> + %875 = torch.aten.detach %874 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %876 = torch.aten.detach %875 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_467 = torch.constant.none + %877 = torch.aten.clone %24, %none_467 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %878 = torch.aten.detach %877 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %879 = torch.aten.detach %878 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %880 = torch.aten.detach %879 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_468 = torch.constant.int 32 + %int2_469 = torch.constant.int 2 + %int8_470 = torch.constant.int 8 + %int32_471 = torch.constant.int 32 + %int128_472 = torch.constant.int 128 + %881 = torch.prim.ListConstruct %456, %int32_468, %int2_469, %int8_470, %int32_471, %int128_472 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %882 = torch.aten.view %868, %881 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %882, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %883 = torch_c.to_builtin_tensor %882 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %884 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_473 = tensor.cast %884 : tensor<4x?xi64> to tensor + %885 = torch_c.to_builtin_tensor %872 : !torch.vtensor<[],si64> -> tensor + %886 = torch_c.to_builtin_tensor %876 : !torch.vtensor<[],si64> -> tensor + %887 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%883, %cast_473, %885, %886) : (tensor, tensor, tensor, tensor) -> tensor + %cast_474 = tensor.cast %887 : tensor to tensor<4x?x8x32x128xf16> + %888 = torch_c.from_builtin_tensor %cast_474 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %888, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %889 = torch_c.to_builtin_tensor %882 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %890 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_475 = tensor.cast %890 : tensor<4x?xi64> to tensor + %891 = torch_c.to_builtin_tensor %872 : !torch.vtensor<[],si64> -> tensor + %892 = torch_c.to_builtin_tensor %880 : !torch.vtensor<[],si64> -> tensor + %893 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%889, %cast_475, %891, %892) : (tensor, tensor, tensor, tensor) -> tensor + %cast_476 = tensor.cast %893 : tensor to tensor<4x?x8x32x128xf16> + %894 = torch_c.from_builtin_tensor %cast_476 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %894, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_477 = torch.constant.int 2 + %int3_478 = torch.constant.int 3 + %895 = torch.aten.transpose.int %888, %int2_477, %int3_478 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %895, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : 
!torch.vtensor<[4,?,32,8,128],f16> + %int0_479 = torch.constant.int 0 + %896 = torch.aten.clone %895, %int0_479 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %896, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_480 = torch.constant.int 4 + %int8_481 = torch.constant.int 8 + %int128_482 = torch.constant.int 128 + %897 = torch.prim.ListConstruct %int4_480, %457, %int8_481, %int128_482 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %898 = torch.aten._unsafe_view %896, %897 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %898, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_483 = torch.constant.int 2 + %int3_484 = torch.constant.int 3 + %899 = torch.aten.transpose.int %894, %int2_483, %int3_484 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %899, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_485 = torch.constant.int 0 + %900 = torch.aten.clone %899, %int0_485 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %900, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_486 = torch.constant.int 4 + %int8_487 = torch.constant.int 8 + %int128_488 = torch.constant.int 128 + %901 = torch.prim.ListConstruct %int4_486, %457, %int8_487, %int128_488 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %902 = torch.aten._unsafe_view %900, %901 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %902, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_489 = torch.constant.int -2 + %903 = torch.aten.unsqueeze %898, %int-2_489 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %903, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_490 = torch.constant.int 4 + %int8_491 = torch.constant.int 8 + %int4_492 = torch.constant.int 4 + %int128_493 = torch.constant.int 128 + %904 = torch.prim.ListConstruct %int4_490, %457, %int8_491, %int4_492, %int128_493 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_494 = torch.constant.bool false + %905 = torch.aten.expand %903, %904, %false_494 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %905, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_495 = torch.constant.int 0 + %906 = torch.aten.clone %905, %int0_495 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %906, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_496 = torch.constant.int 4 + %int32_497 = torch.constant.int 32 + %int128_498 = torch.constant.int 128 + %907 = torch.prim.ListConstruct %int4_496, %457, %int32_497, %int128_498 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %908 = torch.aten._unsafe_view %906, %907 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + 
torch.bind_symbolic_shape %908, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_499 = torch.constant.int -2 + %909 = torch.aten.unsqueeze %902, %int-2_499 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %909, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_500 = torch.constant.int 4 + %int8_501 = torch.constant.int 8 + %int4_502 = torch.constant.int 4 + %int128_503 = torch.constant.int 128 + %910 = torch.prim.ListConstruct %int4_500, %457, %int8_501, %int4_502, %int128_503 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_504 = torch.constant.bool false + %911 = torch.aten.expand %909, %910, %false_504 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %911, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_505 = torch.constant.int 0 + %912 = torch.aten.clone %911, %int0_505 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %912, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_506 = torch.constant.int 4 - %int8_507 = torch.constant.int 8 - %int4_508 = torch.constant.int 4 - %int128_509 = torch.constant.int 128 - %743 = torch.prim.ListConstruct %int4_506, %742, %int8_507, %int4_508, %int128_509 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_510 = torch.constant.bool false - %744 = torch.aten.expand %741, %743, %false_510 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %744, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_511 = torch.constant.int 0 - %745 = torch.aten.clone %744, %int0_511 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %745, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_512 = torch.constant.int 4 - %int32_513 = torch.constant.int 32 - %int128_514 = torch.constant.int 128 - %746 = torch.prim.ListConstruct %int4_512, %742, %int32_513, %int128_514 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %747 = torch.aten._unsafe_view %745, %746 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %747, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_515 = torch.constant.int -2 - %748 = torch.aten.unsqueeze %740, %int-2_515 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %748, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_516 = torch.constant.int 1 - %749 = torch.aten.size.int %739, %int1_516 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_517 = torch.constant.int 4 - %int8_518 = torch.constant.int 8 - %int4_519 = torch.constant.int 4 - %int128_520 = torch.constant.int 128 - %750 = torch.prim.ListConstruct %int4_517, %749, %int8_518, %int4_519, %int128_520 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_521 = torch.constant.bool false - %751 = torch.aten.expand %748, %750, %false_521 : 
!torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %751, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_522 = torch.constant.int 0 - %752 = torch.aten.clone %751, %int0_522 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %752, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_523 = torch.constant.int 4 - %int32_524 = torch.constant.int 32 - %int128_525 = torch.constant.int 128 - %753 = torch.prim.ListConstruct %int4_523, %749, %int32_524, %int128_525 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %754 = torch.aten._unsafe_view %752, %753 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %754, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_526 = torch.constant.int 1 - %int2_527 = torch.constant.int 2 - %755 = torch.aten.transpose.int %635, %int1_526, %int2_527 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_528 = torch.constant.int 1 - %int2_529 = torch.constant.int 2 - %756 = torch.aten.transpose.int %747, %int1_528, %int2_529 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %756, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_530 = torch.constant.int 1 - %int2_531 = torch.constant.int 2 - %757 = torch.aten.transpose.int %754, %int1_530, %int2_531 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %757, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_532 = torch.constant.float 0.000000e+00 - %false_533 = torch.constant.bool false - %none_534 = torch.constant.none - %758:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%755, %756, %757, %float0.000000e00_532, %false_533, %368, %none_534) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_535 = torch.constant.int 1 - %int2_536 = torch.constant.int 2 - %759 = torch.aten.transpose.int %758#0, %int1_535, %int2_536 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_537 = torch.constant.int 4 + %int32_507 = torch.constant.int 32 + %int128_508 = torch.constant.int 128 + %913 = torch.prim.ListConstruct %int4_506, %457, %int32_507, %int128_508 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %914 = torch.aten._unsafe_view %912, %913 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %914, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_509 = torch.constant.int 1 + %int2_510 = torch.constant.int 2 + %915 = torch.aten.transpose.int %797, %int1_509, %int2_510 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_511 = torch.constant.int 1 + %int2_512 = torch.constant.int 2 + %916 = torch.aten.transpose.int %908, %int1_511, %int2_512 : 
!torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %916, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_513 = torch.constant.int 1 + %int2_514 = torch.constant.int 2 + %917 = torch.aten.transpose.int %914, %int1_513, %int2_514 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %917, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_515 = torch.constant.float 0.000000e+00 + %false_516 = torch.constant.bool false + %none_517 = torch.constant.none + %918:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%915, %916, %917, %float0.000000e00_515, %false_516, %470, %none_517) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_518 = torch.constant.int 1 + %int2_519 = torch.constant.int 2 + %919 = torch.aten.transpose.int %918#0, %int1_518, %int2_519 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_520 = torch.constant.int 4 + %int1_521 = torch.constant.int 1 + %int4096_522 = torch.constant.int 4096 + %920 = torch.prim.ListConstruct %int4_520, %int1_521, %int4096_522 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %921 = torch.aten.view %919, %920 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_523 = torch.constant.int -2 + %int-1_524 = torch.constant.int -1 + %922 = torch.aten.transpose.int %25, %int-2_523, %int-1_524 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_525 = torch.constant.int 5 + %923 = torch.prims.convert_element_type %922, %int5_525 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_526 = torch.constant.int 4 + %int4096_527 = torch.constant.int 4096 + %924 = torch.prim.ListConstruct %int4_526, %int4096_527 : (!torch.int, !torch.int) -> !torch.list + %925 = torch.aten.view %921, %924 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %926 = torch.aten.mm %925, %923 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_528 = torch.constant.int 4 + %int1_529 = torch.constant.int 1 + %int4096_530 = torch.constant.int 4096 + %927 = torch.prim.ListConstruct %int4_528, %int1_529, %int4096_530 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %928 = torch.aten.view %926, %927 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_531 = torch.constant.int 1 + %929 = torch.aten.add.Tensor %750, %928, %int1_531 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_532 = torch.constant.int 6 + %930 = torch.prims.convert_element_type %929, %int6_532 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_533 = torch.constant.int 2 + %931 = torch.aten.pow.Tensor_Scalar %930, %int2_533 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_534 = torch.constant.int -1 + %932 = torch.prim.ListConstruct %int-1_534 : (!torch.int) -> !torch.list + %true_535 = torch.constant.bool true + %none_536 = 
torch.constant.none + %933 = torch.aten.mean.dim %931, %932, %true_535, %none_536 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_537 = torch.constant.float 9.9999997473787516E-6 %int1_538 = torch.constant.int 1 - %int4096_539 = torch.constant.int 4096 - %760 = torch.prim.ListConstruct %int4_537, %int1_538, %int4096_539 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %761 = torch.aten.view %759, %760 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_540 = torch.constant.int -2 - %int-1_541 = torch.constant.int -1 - %762 = torch.aten.transpose.int %18, %int-2_540, %int-1_541 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_542 = torch.constant.int 4 - %int4096_543 = torch.constant.int 4096 - %763 = torch.prim.ListConstruct %int4_542, %int4096_543 : (!torch.int, !torch.int) -> !torch.list - %764 = torch.aten.view %761, %763 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %765 = torch.aten.mm %764, %762 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %934 = torch.aten.add.Scalar %933, %float9.999990e-06_537, %int1_538 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %935 = torch.aten.rsqrt %934 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %936 = torch.aten.mul.Tensor %930, %935 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_539 = torch.constant.int 5 + %937 = torch.prims.convert_element_type %936, %int5_539 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %938 = torch.aten.mul.Tensor %26, %937 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_540 = torch.constant.int 5 + %939 = torch.prims.convert_element_type %938, %int5_540 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_541 = torch.constant.int -2 + %int-1_542 = torch.constant.int -1 + %940 = torch.aten.transpose.int %27, %int-2_541, %int-1_542 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_543 = torch.constant.int 5 + %941 = torch.prims.convert_element_type %940, %int5_543 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> %int4_544 = torch.constant.int 4 - %int1_545 = torch.constant.int 1 - %int4096_546 = torch.constant.int 4096 - %766 = torch.prim.ListConstruct %int4_544, %int1_545, %int4096_546 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %767 = torch.aten.view %765, %766 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int4096_545 = torch.constant.int 4096 + %942 = torch.prim.ListConstruct %int4_544, %int4096_545 : (!torch.int, !torch.int) -> !torch.list + %943 = torch.aten.view %939, %942 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %944 = torch.aten.mm %943, %941 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_546 = torch.constant.int 4 %int1_547 = torch.constant.int 1 - %768 = torch.aten.add.Tensor %595, %767, %int1_547 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_548 = torch.constant.int 6 - %769 = torch.prims.convert_element_type %768, %int6_548 : 
!torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_549 = torch.constant.int 2 - %770 = torch.aten.pow.Tensor_Scalar %769, %int2_549 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int14336_548 = torch.constant.int 14336 + %945 = torch.prim.ListConstruct %int4_546, %int1_547, %int14336_548 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %946 = torch.aten.view %944, %945 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %947 = torch.aten.silu %946 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_549 = torch.constant.int -2 %int-1_550 = torch.constant.int -1 - %771 = torch.prim.ListConstruct %int-1_550 : (!torch.int) -> !torch.list - %true_551 = torch.constant.bool true - %none_552 = torch.constant.none - %772 = torch.aten.mean.dim %770, %771, %true_551, %none_552 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_553 = torch.constant.float 9.9999997473787516E-6 - %int1_554 = torch.constant.int 1 - %773 = torch.aten.add.Scalar %772, %float9.999990e-06_553, %int1_554 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %774 = torch.aten.rsqrt %773 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %775 = torch.aten.mul.Tensor %769, %774 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_555 = torch.constant.int 5 - %776 = torch.prims.convert_element_type %775, %int5_555 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %777 = torch.aten.mul.Tensor %19, %776 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_556 = torch.constant.int 5 - %778 = torch.prims.convert_element_type %777, %int5_556 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %948 = torch.aten.transpose.int %28, %int-2_549, %int-1_550 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_551 = torch.constant.int 5 + %949 = torch.prims.convert_element_type %948, %int5_551 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_552 = torch.constant.int 4 + %int4096_553 = torch.constant.int 4096 + %950 = torch.prim.ListConstruct %int4_552, %int4096_553 : (!torch.int, !torch.int) -> !torch.list + %951 = torch.aten.view %939, %950 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %952 = torch.aten.mm %951, %949 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_554 = torch.constant.int 4 + %int1_555 = torch.constant.int 1 + %int14336_556 = torch.constant.int 14336 + %953 = torch.prim.ListConstruct %int4_554, %int1_555, %int14336_556 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %954 = torch.aten.view %952, %953 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %955 = torch.aten.mul.Tensor %947, %954 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> %int-2_557 = torch.constant.int -2 %int-1_558 = torch.constant.int -1 - %779 = torch.aten.transpose.int %20, %int-2_557, %int-1_558 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_559 = torch.constant.int 4 - %int4096_560 = torch.constant.int 4096 - %780 = 
torch.prim.ListConstruct %int4_559, %int4096_560 : (!torch.int, !torch.int) -> !torch.list - %781 = torch.aten.view %778, %780 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %782 = torch.aten.mm %781, %779 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_561 = torch.constant.int 4 - %int1_562 = torch.constant.int 1 - %int14336_563 = torch.constant.int 14336 - %783 = torch.prim.ListConstruct %int4_561, %int1_562, %int14336_563 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %784 = torch.aten.view %782, %783 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %785 = torch.aten.silu %784 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_564 = torch.constant.int -2 - %int-1_565 = torch.constant.int -1 - %786 = torch.aten.transpose.int %21, %int-2_564, %int-1_565 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_566 = torch.constant.int 4 - %int4096_567 = torch.constant.int 4096 - %787 = torch.prim.ListConstruct %int4_566, %int4096_567 : (!torch.int, !torch.int) -> !torch.list - %788 = torch.aten.view %778, %787 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %789 = torch.aten.mm %788, %786 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_568 = torch.constant.int 4 - %int1_569 = torch.constant.int 1 - %int14336_570 = torch.constant.int 14336 - %790 = torch.prim.ListConstruct %int4_568, %int1_569, %int14336_570 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %791 = torch.aten.view %789, %790 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %792 = torch.aten.mul.Tensor %785, %791 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_571 = torch.constant.int -2 - %int-1_572 = torch.constant.int -1 - %793 = torch.aten.transpose.int %22, %int-2_571, %int-1_572 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_573 = torch.constant.int 4 - %int14336_574 = torch.constant.int 14336 - %794 = torch.prim.ListConstruct %int4_573, %int14336_574 : (!torch.int, !torch.int) -> !torch.list - %795 = torch.aten.view %792, %794 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %796 = torch.aten.mm %795, %793 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_575 = torch.constant.int 4 - %int1_576 = torch.constant.int 1 - %int4096_577 = torch.constant.int 4096 - %797 = torch.prim.ListConstruct %int4_575, %int1_576, %int4096_577 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %798 = torch.aten.view %796, %797 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_578 = torch.constant.int 1 - %799 = torch.aten.add.Tensor %768, %798, %int1_578 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_579 = torch.constant.int 6 - %800 = torch.prims.convert_element_type %799, %int6_579 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_580 = torch.constant.int 2 - %801 = torch.aten.pow.Tensor_Scalar %800, %int2_580 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_581 = torch.constant.int -1 - %802 = torch.prim.ListConstruct 
%int-1_581 : (!torch.int) -> !torch.list - %true_582 = torch.constant.bool true - %none_583 = torch.constant.none - %803 = torch.aten.mean.dim %801, %802, %true_582, %none_583 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_584 = torch.constant.float 9.9999997473787516E-6 - %int1_585 = torch.constant.int 1 - %804 = torch.aten.add.Scalar %803, %float9.999990e-06_584, %int1_585 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %805 = torch.aten.rsqrt %804 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %806 = torch.aten.mul.Tensor %800, %805 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_586 = torch.constant.int 5 - %807 = torch.prims.convert_element_type %806, %int5_586 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %808 = torch.aten.mul.Tensor %23, %807 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_587 = torch.constant.int 5 - %809 = torch.prims.convert_element_type %808, %int5_587 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_588 = torch.constant.int -2 - %int-1_589 = torch.constant.int -1 - %810 = torch.aten.transpose.int %24, %int-2_588, %int-1_589 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_590 = torch.constant.int 4 - %int4096_591 = torch.constant.int 4096 - %811 = torch.prim.ListConstruct %int4_590, %int4096_591 : (!torch.int, !torch.int) -> !torch.list - %812 = torch.aten.view %809, %811 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %813 = torch.aten.mm %812, %810 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_592 = torch.constant.int 4 - %int1_593 = torch.constant.int 1 - %int4096_594 = torch.constant.int 4096 - %814 = torch.prim.ListConstruct %int4_592, %int1_593, %int4096_594 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %815 = torch.aten.view %813, %814 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_595 = torch.constant.int -2 - %int-1_596 = torch.constant.int -1 - %816 = torch.aten.transpose.int %25, %int-2_595, %int-1_596 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_597 = torch.constant.int 4 - %int4096_598 = torch.constant.int 4096 - %817 = torch.prim.ListConstruct %int4_597, %int4096_598 : (!torch.int, !torch.int) -> !torch.list - %818 = torch.aten.view %809, %817 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %819 = torch.aten.mm %818, %816 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %956 = torch.aten.transpose.int %29, %int-2_557, %int-1_558 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_559 = torch.constant.int 5 + %957 = torch.prims.convert_element_type %956, %int5_559 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_560 = torch.constant.int 4 + %int14336_561 = torch.constant.int 14336 + %958 = torch.prim.ListConstruct %int4_560, %int14336_561 : (!torch.int, !torch.int) -> !torch.list + %959 = torch.aten.view %955, %958 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %960 = torch.aten.mm %959, 
%957 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_562 = torch.constant.int 4 + %int1_563 = torch.constant.int 1 + %int4096_564 = torch.constant.int 4096 + %961 = torch.prim.ListConstruct %int4_562, %int1_563, %int4096_564 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %962 = torch.aten.view %960, %961 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_565 = torch.constant.int 1 + %963 = torch.aten.add.Tensor %929, %962, %int1_565 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_566 = torch.constant.int 6 + %964 = torch.prims.convert_element_type %963, %int6_566 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_567 = torch.constant.int 2 + %965 = torch.aten.pow.Tensor_Scalar %964, %int2_567 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_568 = torch.constant.int -1 + %966 = torch.prim.ListConstruct %int-1_568 : (!torch.int) -> !torch.list + %true_569 = torch.constant.bool true + %none_570 = torch.constant.none + %967 = torch.aten.mean.dim %965, %966, %true_569, %none_570 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_571 = torch.constant.float 9.9999997473787516E-6 + %int1_572 = torch.constant.int 1 + %968 = torch.aten.add.Scalar %967, %float9.999990e-06_571, %int1_572 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %969 = torch.aten.rsqrt %968 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %970 = torch.aten.mul.Tensor %964, %969 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_573 = torch.constant.int 5 + %971 = torch.prims.convert_element_type %970, %int5_573 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %972 = torch.aten.mul.Tensor %30, %971 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_574 = torch.constant.int 5 + %973 = torch.prims.convert_element_type %972, %int5_574 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_575 = torch.constant.int -2 + %int-1_576 = torch.constant.int -1 + %974 = torch.aten.transpose.int %31, %int-2_575, %int-1_576 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_577 = torch.constant.int 5 + %975 = torch.prims.convert_element_type %974, %int5_577 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_578 = torch.constant.int 4 + %int4096_579 = torch.constant.int 4096 + %976 = torch.prim.ListConstruct %int4_578, %int4096_579 : (!torch.int, !torch.int) -> !torch.list + %977 = torch.aten.view %973, %976 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %978 = torch.aten.mm %977, %975 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_580 = torch.constant.int 4 + %int1_581 = torch.constant.int 1 + %int4096_582 = torch.constant.int 4096 + %979 = torch.prim.ListConstruct %int4_580, %int1_581, %int4096_582 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %980 = torch.aten.view %978, %979 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_583 = torch.constant.int -2 + %int-1_584 = torch.constant.int 
-1 + %981 = torch.aten.transpose.int %32, %int-2_583, %int-1_584 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_585 = torch.constant.int 5 + %982 = torch.prims.convert_element_type %981, %int5_585 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_586 = torch.constant.int 4 + %int4096_587 = torch.constant.int 4096 + %983 = torch.prim.ListConstruct %int4_586, %int4096_587 : (!torch.int, !torch.int) -> !torch.list + %984 = torch.aten.view %973, %983 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %985 = torch.aten.mm %984, %982 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_588 = torch.constant.int 4 + %int1_589 = torch.constant.int 1 + %int1024_590 = torch.constant.int 1024 + %986 = torch.prim.ListConstruct %int4_588, %int1_589, %int1024_590 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %987 = torch.aten.view %985, %986 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_591 = torch.constant.int -2 + %int-1_592 = torch.constant.int -1 + %988 = torch.aten.transpose.int %33, %int-2_591, %int-1_592 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_593 = torch.constant.int 5 + %989 = torch.prims.convert_element_type %988, %int5_593 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_594 = torch.constant.int 4 + %int4096_595 = torch.constant.int 4096 + %990 = torch.prim.ListConstruct %int4_594, %int4096_595 : (!torch.int, !torch.int) -> !torch.list + %991 = torch.aten.view %973, %990 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %992 = torch.aten.mm %991, %989 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_596 = torch.constant.int 4 + %int1_597 = torch.constant.int 1 + %int1024_598 = torch.constant.int 1024 + %993 = torch.prim.ListConstruct %int4_596, %int1_597, %int1024_598 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %994 = torch.aten.view %992, %993 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> %int4_599 = torch.constant.int 4 %int1_600 = torch.constant.int 1 - %int1024_601 = torch.constant.int 1024 - %820 = torch.prim.ListConstruct %int4_599, %int1_600, %int1024_601 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %821 = torch.aten.view %819, %820 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_602 = torch.constant.int -2 - %int-1_603 = torch.constant.int -1 - %822 = torch.aten.transpose.int %26, %int-2_602, %int-1_603 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_604 = torch.constant.int 4 - %int4096_605 = torch.constant.int 4096 - %823 = torch.prim.ListConstruct %int4_604, %int4096_605 : (!torch.int, !torch.int) -> !torch.list - %824 = torch.aten.view %809, %823 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %825 = torch.aten.mm %824, %822 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_606 = torch.constant.int 4 - %int1_607 = torch.constant.int 1 - %int1024_608 = torch.constant.int 1024 - %826 = torch.prim.ListConstruct %int4_606, %int1_607, %int1024_608 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %827 = torch.aten.view %825, %826 : 
!torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_609 = torch.constant.int 4 - %int1_610 = torch.constant.int 1 - %int32_611 = torch.constant.int 32 - %int128_612 = torch.constant.int 128 - %828 = torch.prim.ListConstruct %int4_609, %int1_610, %int32_611, %int128_612 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %829 = torch.aten.view %815, %828 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_613 = torch.constant.int 4 - %int1_614 = torch.constant.int 1 - %int8_615 = torch.constant.int 8 - %int128_616 = torch.constant.int 128 - %830 = torch.prim.ListConstruct %int4_613, %int1_614, %int8_615, %int128_616 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %831 = torch.aten.view %821, %830 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_617 = torch.constant.int 4 - %int1_618 = torch.constant.int 1 - %int8_619 = torch.constant.int 8 - %int128_620 = torch.constant.int 128 - %832 = torch.prim.ListConstruct %int4_617, %int1_618, %int8_619, %int128_620 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %833 = torch.aten.view %827, %832 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_621 = torch.constant.int 6 - %834 = torch.prims.convert_element_type %829, %int6_621 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %835 = torch_c.to_builtin_tensor %834 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %836 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %837 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%835, %836) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %838 = torch_c.from_builtin_tensor %837 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_622 = torch.constant.int 5 - %839 = torch.prims.convert_element_type %838, %int5_622 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_623 = torch.constant.int 6 - %840 = torch.prims.convert_element_type %831, %int6_623 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %841 = torch_c.to_builtin_tensor %840 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %842 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %843 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%841, %842) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %844 = torch_c.from_builtin_tensor %843 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_624 = torch.constant.int 5 - %845 = torch.prims.convert_element_type %844, %int5_624 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_625 = torch.constant.int 32 - %846 = torch.aten.floor_divide.Scalar %arg2, %int32_625 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_626 = torch.constant.int 1 - %847 = torch.aten.unsqueeze %846, %int1_626 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_627 = torch.constant.int 1 - %false_628 = torch.constant.bool false - %848 = torch.aten.gather %arg3, %int1_627, %847, %false_628 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_629 = torch.constant.int 32 - %849 = torch.aten.remainder.Scalar %arg2, %int32_629 : 
!torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int32_601 = torch.constant.int 32 + %int128_602 = torch.constant.int 128 + %995 = torch.prim.ListConstruct %int4_599, %int1_600, %int32_601, %int128_602 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %996 = torch.aten.view %980, %995 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_603 = torch.constant.int 4 + %int1_604 = torch.constant.int 1 + %int8_605 = torch.constant.int 8 + %int128_606 = torch.constant.int 128 + %997 = torch.prim.ListConstruct %int4_603, %int1_604, %int8_605, %int128_606 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %998 = torch.aten.view %987, %997 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_607 = torch.constant.int 4 + %int1_608 = torch.constant.int 1 + %int8_609 = torch.constant.int 8 + %int128_610 = torch.constant.int 128 + %999 = torch.prim.ListConstruct %int4_607, %int1_608, %int8_609, %int128_610 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1000 = torch.aten.view %994, %999 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_611 = torch.constant.int 1 + %int2_612 = torch.constant.int 2 + %1001 = torch.aten.transpose.int %996, %int1_611, %int2_612 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %1002 = torch.aten.mul.Tensor %1001, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_613 = torch.constant.int 3 + %int0_614 = torch.constant.int 0 + %int64_615 = torch.constant.int 64 + %int1_616 = torch.constant.int 1 + %1003 = torch.aten.slice.Tensor %1001, %int3_613, %int0_614, %int64_615, %int1_616 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_617 = torch.constant.int 3 + %int64_618 = torch.constant.int 64 + %int9223372036854775807_619 = torch.constant.int 9223372036854775807 + %int1_620 = torch.constant.int 1 + %1004 = torch.aten.slice.Tensor %1001, %int3_617, %int64_618, %int9223372036854775807_619, %int1_620 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %1005 = torch.aten.neg %1004 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %1006 = torch.prim.ListConstruct %1005, %1003 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_621 = torch.constant.int -1 + %1007 = torch.aten.cat %1006, %int-1_621 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %1008 = torch.aten.mul.Tensor %1007, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_622 = torch.constant.int 1 + %1009 = torch.aten.add.Tensor %1002, %1008, %int1_622 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_623 = torch.constant.int 1 + %int2_624 = torch.constant.int 2 + %1010 = torch.aten.transpose.int %1009, %int1_623, %int2_624 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_625 = torch.constant.int 1 + %int2_626 = torch.constant.int 2 + %1011 = torch.aten.transpose.int %998, %int1_625, %int2_626 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %1012 = torch.aten.mul.Tensor 
%1011, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_627 = torch.constant.int 3 + %int0_628 = torch.constant.int 0 + %int64_629 = torch.constant.int 64 %int1_630 = torch.constant.int 1 - %850 = torch.aten.unsqueeze %849, %int1_630 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_631 = torch.constant.none - %851 = torch.aten.clone %27, %none_631 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_632 = torch.constant.int 0 - %852 = torch.aten.unsqueeze %851, %int0_632 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_633 = torch.constant.int 4 + %1013 = torch.aten.slice.Tensor %1011, %int3_627, %int0_628, %int64_629, %int1_630 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_631 = torch.constant.int 3 + %int64_632 = torch.constant.int 64 + %int9223372036854775807_633 = torch.constant.int 9223372036854775807 %int1_634 = torch.constant.int 1 - %853 = torch.prim.ListConstruct %int4_633, %int1_634 : (!torch.int, !torch.int) -> !torch.list - %int1_635 = torch.constant.int 1 + %1014 = torch.aten.slice.Tensor %1011, %int3_631, %int64_632, %int9223372036854775807_633, %int1_634 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %1015 = torch.aten.neg %1014 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %1016 = torch.prim.ListConstruct %1015, %1013 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_635 = torch.constant.int -1 + %1017 = torch.aten.cat %1016, %int-1_635 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %1018 = torch.aten.mul.Tensor %1017, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> %int1_636 = torch.constant.int 1 - %854 = torch.prim.ListConstruct %int1_635, %int1_636 : (!torch.int, !torch.int) -> !torch.list - %int4_637 = torch.constant.int 4 - %int0_638 = torch.constant.int 0 - %cpu_639 = torch.constant.device "cpu" - %false_640 = torch.constant.bool false - %855 = torch.aten.empty_strided %853, %854, %int4_637, %int0_638, %cpu_639, %false_640 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int2_641 = torch.constant.int 2 - %856 = torch.aten.fill.Scalar %855, %int2_641 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_642 = torch.constant.int 4 - %int1_643 = torch.constant.int 1 - %857 = torch.prim.ListConstruct %int4_642, %int1_643 : (!torch.int, !torch.int) -> !torch.list - %858 = torch.aten.repeat %852, %857 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_644 = torch.constant.int 32 - %859 = torch.aten.mul.Scalar %848, %int32_644 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %1019 = torch.aten.add.Tensor %1012, %1018, %int1_636 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_637 = torch.constant.int 1 + %int2_638 = torch.constant.int 2 + %1020 = torch.aten.transpose.int %1019, %int1_637, %int2_638 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_639 = torch.constant.int 32 + %1021 = torch.aten.floor_divide.Scalar %arg2, %int32_639 : !torch.vtensor<[4],si64>, !torch.int -> 
!torch.vtensor<[4],si64> + %int1_640 = torch.constant.int 1 + %1022 = torch.aten.unsqueeze %1021, %int1_640 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_641 = torch.constant.int 1 + %false_642 = torch.constant.bool false + %1023 = torch.aten.gather %arg3, %int1_641, %1022, %false_642 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_643 = torch.constant.int 4 + %int1_644 = torch.constant.int 1 %int1_645 = torch.constant.int 1 - %860 = torch.aten.add.Tensor %859, %856, %int1_645 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_646 = torch.constant.int 2 - %861 = torch.aten.mul.Scalar %860, %int2_646 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_647 = torch.constant.int 1 - %862 = torch.aten.add.Tensor %861, %858, %int1_647 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_648 = torch.constant.int 32 - %863 = torch.aten.mul.Scalar %862, %int32_648 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %1024 = torch.prim.ListConstruct %int4_643, %int1_644, %int1_645 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1025 = torch.aten.view %1023, %1024 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_646 = torch.constant.int 32 + %1026 = torch.aten.remainder.Scalar %arg2, %int32_646 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_647 = torch.constant.int 4 + %int1_648 = torch.constant.int 1 %int1_649 = torch.constant.int 1 - %864 = torch.aten.add.Tensor %863, %850, %int1_649 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_650 = torch.constant.int 32 - %int2_651 = torch.constant.int 2 - %int32_652 = torch.constant.int 32 - %int8_653 = torch.constant.int 8 - %int128_654 = torch.constant.int 128 - %865 = torch.prim.ListConstruct %437, %int32_650, %int2_651, %int32_652, %int8_653, %int128_654 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %866 = torch.aten.view %702, %865 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %866, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_655 = torch.constant.int 32 - %867 = torch.aten.mul.int %437, %int32_655 : !torch.int, !torch.int -> !torch.int - %int2_656 = torch.constant.int 2 - %868 = torch.aten.mul.int %867, %int2_656 : !torch.int, !torch.int -> !torch.int - %int32_657 = torch.constant.int 32 - %869 = torch.aten.mul.int %868, %int32_657 : !torch.int, !torch.int -> !torch.int - %int8_658 = torch.constant.int 8 - %int128_659 = torch.constant.int 128 - %870 = torch.prim.ListConstruct %869, %int8_658, %int128_659 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %871 = torch.aten.view %866, %870 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %871, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %872 = torch.prim.ListConstruct %864 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_660 = torch.constant.bool false - %873 = torch.aten.index_put %871, %872, %845, %false_660 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - 
torch.bind_symbolic_shape %873, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_661 = torch.constant.int 32 - %int2_662 = torch.constant.int 2 - %int32_663 = torch.constant.int 32 - %int8_664 = torch.constant.int 8 - %int128_665 = torch.constant.int 128 - %874 = torch.prim.ListConstruct %437, %int32_661, %int2_662, %int32_663, %int8_664, %int128_665 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %875 = torch.aten.view %873, %874 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %875, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_666 = torch.constant.int 2097152 - %876 = torch.prim.ListConstruct %437, %int2097152_666 : (!torch.int, !torch.int) -> !torch.list - %877 = torch.aten.view %875, %876 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %877, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_667 = torch.constant.int 32 - %int2_668 = torch.constant.int 2 + %1027 = torch.prim.ListConstruct %int4_647, %int1_648, %int1_649 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1028 = torch.aten.view %1026, %1027 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_650 = torch.constant.int 8 + %none_651 = torch.constant.none + %none_652 = torch.constant.none + %cpu_653 = torch.constant.device "cpu" + %false_654 = torch.constant.bool false + %1029 = torch.aten.arange %int8_650, %none_651, %none_652, %cpu_653, %false_654 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_655 = torch.constant.int 1 + %int1_656 = torch.constant.int 1 + %int8_657 = torch.constant.int 8 + %1030 = torch.prim.ListConstruct %int1_655, %int1_656, %int8_657 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1031 = torch.aten.view %1029, %1030 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_658 = torch.constant.none + %1032 = torch.aten.clone %34, %none_658 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1033 = torch.aten.detach %1032 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1034 = torch.aten.detach %1033 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1035 = torch.aten.detach %1034 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_659 = torch.constant.int 1 + %int1_660 = torch.constant.int 1 + %int1_661 = torch.constant.int 1 + %1036 = torch.prim.ListConstruct %int1_659, %int1_660, %int1_661 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1037 = torch.aten.view %1035, %1036 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_662 = torch.constant.int 32 + %1038 = torch.aten.mul.Scalar %1025, %int32_662 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_663 = torch.constant.int 2 + %int1_664 = torch.constant.int 1 + %1039 = torch.aten.add.Scalar %1038, %int2_663, %int1_664 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_665 = torch.constant.int 2 + %1040 = torch.aten.mul.Scalar %1039, %int2_665 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_666 = torch.constant.int 1 + %1041 = torch.aten.add.Tensor %1040, %1037, %int1_666 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, 
!torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_667 = torch.constant.int 8 + %1042 = torch.aten.mul.Scalar %1041, %int8_667 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_668 = torch.constant.int 1 + %1043 = torch.aten.add.Tensor %1042, %1031, %int1_668 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int32_669 = torch.constant.int 32 - %int8_670 = torch.constant.int 8 - %int128_671 = torch.constant.int 128 - %878 = torch.prim.ListConstruct %437, %int32_667, %int2_668, %int32_669, %int8_670, %int128_671 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %879 = torch.aten.view %877, %878 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %879, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_672 = torch.constant.int 8 - %int128_673 = torch.constant.int 128 - %880 = torch.prim.ListConstruct %869, %int8_672, %int128_673 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %881 = torch.aten.view %879, %880 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %881, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_674 = torch.constant.int 32 - %882 = torch.aten.floor_divide.Scalar %arg2, %int32_674 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_675 = torch.constant.int 1 - %883 = torch.aten.unsqueeze %882, %int1_675 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_676 = torch.constant.int 1 - %false_677 = torch.constant.bool false - %884 = torch.aten.gather %arg3, %int1_676, %883, %false_677 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_678 = torch.constant.int 32 - %885 = torch.aten.remainder.Scalar %arg2, %int32_678 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_679 = torch.constant.int 1 - %886 = torch.aten.unsqueeze %885, %int1_679 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_680 = torch.constant.none - %887 = torch.aten.clone %28, %none_680 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_681 = torch.constant.int 0 - %888 = torch.aten.unsqueeze %887, %int0_681 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_682 = torch.constant.int 4 - %int1_683 = torch.constant.int 1 - %889 = torch.prim.ListConstruct %int4_682, %int1_683 : (!torch.int, !torch.int) -> !torch.list - %int1_684 = torch.constant.int 1 - %int1_685 = torch.constant.int 1 - %890 = torch.prim.ListConstruct %int1_684, %int1_685 : (!torch.int, !torch.int) -> !torch.list - %int4_686 = torch.constant.int 4 - %int0_687 = torch.constant.int 0 - %cpu_688 = torch.constant.device "cpu" - %false_689 = torch.constant.bool false - %891 = torch.aten.empty_strided %889, %890, %int4_686, %int0_687, %cpu_688, %false_689 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int2_690 = torch.constant.int 2 - %892 = torch.aten.fill.Scalar %891, %int2_690 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_691 = torch.constant.int 4 + %1044 = torch.aten.mul.Scalar %1043, %int32_669 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_670 = 
torch.constant.int 1 + %1045 = torch.aten.add.Tensor %1044, %1028, %int1_670 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_671 = torch.constant.int 5 + %1046 = torch.prims.convert_element_type %1020, %int5_671 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_672 = torch.constant.int 32 + %int2_673 = torch.constant.int 2 + %int8_674 = torch.constant.int 8 + %int32_675 = torch.constant.int 32 + %int128_676 = torch.constant.int 128 + %1047 = torch.prim.ListConstruct %456, %int32_672, %int2_673, %int8_674, %int32_675, %int128_676 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %1048 = torch.aten.view %868, %1047 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1048, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_677 = torch.constant.int 128 + %1049 = torch.prim.ListConstruct %596, %int128_677 : (!torch.int, !torch.int) -> !torch.list<int> + %1050 = torch.aten.view %1048, %1049 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1050, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %1051 = torch.prim.ListConstruct %1045 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>> + %false_678 = torch.constant.bool false + %1052 = torch.aten.index_put %1050, %1051, %1046, %false_678 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1052, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_679 = torch.constant.int 32 + %int2_680 = torch.constant.int 2 + %int8_681 = torch.constant.int 8 + %int32_682 = torch.constant.int 32 + %int128_683 = torch.constant.int 128 + %1053 = torch.prim.ListConstruct %456, %int32_679, %int2_680, %int8_681, %int32_682, %int128_683 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %1054 = torch.aten.view %1052, %1053 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1054, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_684 = torch.constant.int 2097152 + %1055 = torch.prim.ListConstruct %456, %int2097152_684 : (!torch.int, !torch.int) -> !torch.list<int> + %1056 = torch.aten.view %1054, %1055 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %1056, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_685 = torch.constant.int 32 + %int2_686 = torch.constant.int 2 + %int8_687 = torch.constant.int 8 + %int32_688 = torch.constant.int 32 + %int128_689 = torch.constant.int 128 + %1057 = torch.prim.ListConstruct %456, %int32_685, %int2_686, %int8_687, %int32_688, %int128_689 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %1058 = torch.aten.view %1056, %1057 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1058, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_690 = torch.constant.int 128 + %1059 = torch.prim.ListConstruct %596, %int128_690 : (!torch.int, !torch.int) -> !torch.list<int> 
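+ // The scalar arithmetic above builds flat row indices into the paged KV cache: the [?,2097152]
+ // buffer is viewed as [pages, 32 layers, 2 (K|V), 8 kv-heads, 32 positions, 128 dims] and then as
+ // [pages * 16384, 128], so row = (((page * 32 + layer) * 2 + part) * 8 + head) * 32 + pos, with
+ // layer = 2 for this block; index_put then writes this step's K vectors (the V write follows below).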
+ %1060 = torch.aten.view %1058, %1059 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1060, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_691 = torch.constant.none + %1061 = torch.aten.clone %35, %none_691 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1062 = torch.aten.detach %1061 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1063 = torch.aten.detach %1062 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1064 = torch.aten.detach %1063 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %int1_692 = torch.constant.int 1 - %893 = torch.prim.ListConstruct %int4_691, %int1_692 : (!torch.int, !torch.int) -> !torch.list - %894 = torch.aten.repeat %888, %893 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_693 = torch.constant.int 32 - %895 = torch.aten.mul.Scalar %884, %int32_693 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_693 = torch.constant.int 1 %int1_694 = torch.constant.int 1 - %896 = torch.aten.add.Tensor %895, %892, %int1_694 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_695 = torch.constant.int 2 - %897 = torch.aten.mul.Scalar %896, %int2_695 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_696 = torch.constant.int 1 - %898 = torch.aten.add.Tensor %897, %894, %int1_696 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_697 = torch.constant.int 32 - %899 = torch.aten.mul.Scalar %898, %int32_697 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_698 = torch.constant.int 1 - %900 = torch.aten.add.Tensor %899, %886, %int1_698 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %901 = torch.prim.ListConstruct %900 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_699 = torch.constant.bool false - %902 = torch.aten.index_put %881, %901, %833, %false_699 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %902, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_700 = torch.constant.int 32 - %int2_701 = torch.constant.int 2 + %1065 = torch.prim.ListConstruct %int1_692, %int1_693, %int1_694 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1066 = torch.aten.view %1064, %1065 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_695 = torch.constant.int 32 + %1067 = torch.aten.mul.Scalar %1025, %int32_695 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_696 = torch.constant.int 2 + %int1_697 = torch.constant.int 1 + %1068 = torch.aten.add.Scalar %1067, %int2_696, %int1_697 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_698 = torch.constant.int 2 + %1069 = torch.aten.mul.Scalar %1068, %int2_698 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_699 = torch.constant.int 1 + %1070 = torch.aten.add.Tensor %1069, %1066, %int1_699 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_700 = torch.constant.int 8 + %1071 = torch.aten.mul.Scalar %1070, %int8_700 : !torch.vtensor<[4,1,1],si64>, !torch.int -> 
!torch.vtensor<[4,1,1],si64> + %int1_701 = torch.constant.int 1 + %1072 = torch.aten.add.Tensor %1071, %1031, %int1_701 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int32_702 = torch.constant.int 32 - %int8_703 = torch.constant.int 8 - %int128_704 = torch.constant.int 128 - %903 = torch.prim.ListConstruct %437, %int32_700, %int2_701, %int32_702, %int8_703, %int128_704 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %904 = torch.aten.view %902, %903 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %904, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_705 = torch.constant.int 2097152 - %905 = torch.prim.ListConstruct %437, %int2097152_705 : (!torch.int, !torch.int) -> !torch.list - %906 = torch.aten.view %904, %905 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %906, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_706 = torch.constant.int 4 - %907 = torch.prim.ListConstruct %int4_706, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_707 = torch.constant.int 1 - %908 = torch.prim.ListConstruct %358, %int1_707 : (!torch.int, !torch.int) -> !torch.list - %int4_708 = torch.constant.int 4 - %int0_709 = torch.constant.int 0 - %cpu_710 = torch.constant.device "cpu" - %false_711 = torch.constant.bool false - %909 = torch.aten.empty_strided %907, %908, %int4_708, %int0_709, %cpu_710, %false_711 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %909, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int2_712 = torch.constant.int 2 - %910 = torch.aten.fill.Scalar %909, %int2_712 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %910, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_713 = torch.constant.int 32 - %911 = torch.aten.mul.Scalar %arg3, %int32_713 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %911, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_714 = torch.constant.int 1 - %912 = torch.aten.add.Tensor %911, %910, %int1_714 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %912, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_715 = torch.constant.int 4 - %913 = torch.aten.mul.int %int4_715, %358 : !torch.int, !torch.int -> !torch.int - %914 = torch.prim.ListConstruct %913 : (!torch.int) -> !torch.list - %915 = torch.aten.view %912, %914 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %915, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_716 = torch.constant.int 32 - %int2_717 = torch.constant.int 2 + %1073 = torch.aten.mul.Scalar %1072, %int32_702 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_703 = torch.constant.int 1 + %1074 = torch.aten.add.Tensor %1073, %1028, %int1_703 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_704 = torch.constant.int 5 + %1075 = torch.prims.convert_element_type %1000, %int5_704 : 
!torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %1076 = torch.prim.ListConstruct %1074 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>> + %false_705 = torch.constant.bool false + %1077 = torch.aten.index_put %1060, %1076, %1075, %false_705 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1077, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_706 = torch.constant.int 32 + %int2_707 = torch.constant.int 2 + %int8_708 = torch.constant.int 8 + %int32_709 = torch.constant.int 32 + %int128_710 = torch.constant.int 128 + %1078 = torch.prim.ListConstruct %456, %int32_706, %int2_707, %int8_708, %int32_709, %int128_710 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %1079 = torch.aten.view %1077, %1078 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1079, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_711 = torch.constant.int 2097152 + %1080 = torch.prim.ListConstruct %456, %int2097152_711 : (!torch.int, !torch.int) -> !torch.list<int> + %1081 = torch.aten.view %1079, %1080 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %1081, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_712 = torch.constant.none + %1082 = torch.aten.clone %36, %none_712 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1083 = torch.aten.detach %1082 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1084 = torch.aten.detach %1083 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1085 = torch.aten.detach %1084 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_713 = torch.constant.none + %1086 = torch.aten.clone %37, %none_713 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1087 = torch.aten.detach %1086 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1088 = torch.aten.detach %1087 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1089 = torch.aten.detach %1088 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_714 = torch.constant.none + %1090 = torch.aten.clone %38, %none_714 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1091 = torch.aten.detach %1090 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1092 = torch.aten.detach %1091 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1093 = torch.aten.detach %1092 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_715 = torch.constant.int 32 + %int2_716 = torch.constant.int 2 + %int8_717 = torch.constant.int 8 %int32_718 = torch.constant.int 32 - %int8_719 = torch.constant.int 8 - %int128_720 = torch.constant.int 128 - %916 = torch.prim.ListConstruct %437, %int32_716, %int2_717, %int32_718, %int8_719, %int128_720 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %917 = torch.aten.view %906, %916 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %917, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_721 = torch.constant.int 32 - %918 = torch.aten.mul.int %437, %int32_721 : !torch.int, !torch.int -> !torch.int - %int2_722 = torch.constant.int 2
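+ // The two util.call gathers below replace the old flatten/index_select path: with the cache viewed
+ // as [pages, 32, 2, 8, 32, 128], each call walks the page table %arg3 (cast to a fully dynamic
+ // shape for the templated callee) and collects one partition of this layer's cache — %1085 supplies
+ // the layer index and %1089/%1093 select K vs. V — yielding [4, pages, 8, 32, 128] tensors for attention.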
torch.constant.int 32 - %int8_724 = torch.constant.int 8 - %int128_725 = torch.constant.int 128 - %919 = torch.prim.ListConstruct %918, %int2_722, %int32_723, %int8_724, %int128_725 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %920 = torch.aten.view %917, %919 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %920, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> + %int128_719 = torch.constant.int 128 + %1094 = torch.prim.ListConstruct %456, %int32_715, %int2_716, %int8_717, %int32_718, %int128_719 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1095 = torch.aten.view %1081, %1094 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1095, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %1096 = torch_c.to_builtin_tensor %1095 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %1097 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_720 = tensor.cast %1097 : tensor<4x?xi64> to tensor + %1098 = torch_c.to_builtin_tensor %1085 : !torch.vtensor<[],si64> -> tensor + %1099 = torch_c.to_builtin_tensor %1089 : !torch.vtensor<[],si64> -> tensor + %1100 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%1096, %cast_720, %1098, %1099) : (tensor, tensor, tensor, tensor) -> tensor + %cast_721 = tensor.cast %1100 : tensor to tensor<4x?x8x32x128xf16> + %1101 = torch_c.from_builtin_tensor %cast_721 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %1101, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %1102 = torch_c.to_builtin_tensor %1095 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %1103 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_722 = tensor.cast %1103 : tensor<4x?xi64> to tensor + %1104 = torch_c.to_builtin_tensor %1085 : !torch.vtensor<[],si64> -> tensor + %1105 = torch_c.to_builtin_tensor %1093 : !torch.vtensor<[],si64> -> tensor + %1106 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%1102, %cast_722, %1104, %1105) : (tensor, tensor, tensor, tensor) -> tensor + %cast_723 = tensor.cast %1106 : tensor to tensor<4x?x8x32x128xf16> + %1107 = torch_c.from_builtin_tensor %cast_723 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %1107, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_724 = torch.constant.int 2 + %int3_725 = torch.constant.int 3 + %1108 = torch.aten.transpose.int %1101, %int2_724, %int3_725 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1108, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int0_726 = torch.constant.int 0 - %921 = torch.aten.index_select %920, %int0_726, %915 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> 
!torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %921, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> + %1109 = torch.aten.clone %1108, %int0_726 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1109, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int4_727 = torch.constant.int 4 - %int2_728 = torch.constant.int 2 - %int32_729 = torch.constant.int 32 - %int8_730 = torch.constant.int 8 - %int128_731 = torch.constant.int 128 - %922 = torch.prim.ListConstruct %int4_727, %358, %int2_728, %int32_729, %int8_730, %int128_731 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %923 = torch.aten.view %921, %922 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %923, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> + %int8_728 = torch.constant.int 8 + %int128_729 = torch.constant.int 128 + %1110 = torch.prim.ListConstruct %int4_727, %457, %int8_728, %int128_729 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1111 = torch.aten._unsafe_view %1109, %1110 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1111, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_730 = torch.constant.int 2 + %int3_731 = torch.constant.int 3 + %1112 = torch.aten.transpose.int %1107, %int2_730, %int3_731 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1112, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int0_732 = torch.constant.int 0 - %int0_733 = torch.constant.int 0 - %int9223372036854775807_734 = torch.constant.int 9223372036854775807 - %int1_735 = torch.constant.int 1 - %924 = torch.aten.slice.Tensor %923, %int0_732, %int0_733, %int9223372036854775807_734, %int1_735 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %924, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_736 = torch.constant.int 1 - %int0_737 = torch.constant.int 0 - %int9223372036854775807_738 = torch.constant.int 9223372036854775807 - %int1_739 = torch.constant.int 1 - %925 = torch.aten.slice.Tensor %924, %int1_736, %int0_737, %int9223372036854775807_738, %int1_739 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %925, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_740 = torch.constant.int 2 - %int0_741 = torch.constant.int 0 - %926 = torch.aten.select.int %925, %int2_740, %int0_741 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %926, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_742 = torch.constant.int 32 - %927 = torch.aten.mul.int %358, %int32_742 : !torch.int, !torch.int -> !torch.int - %int2_743 = torch.constant.int 2 - %int0_744 = torch.constant.int 0 - %int1_745 = torch.constant.int 1 - %928 = torch.aten.slice.Tensor %926, %int2_743, %int0_744, %927, 
%int1_745 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %928, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_746 = torch.constant.int 0 - %929 = torch.aten.clone %928, %int0_746 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %929, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_747 = torch.constant.int 1 - %930 = torch.aten.size.int %925, %int1_747 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_748 = torch.constant.int 32 - %931 = torch.aten.mul.int %930, %int32_748 : !torch.int, !torch.int -> !torch.int + %1113 = torch.aten.clone %1112, %int0_732 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1113, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_733 = torch.constant.int 4 + %int8_734 = torch.constant.int 8 + %int128_735 = torch.constant.int 128 + %1114 = torch.prim.ListConstruct %int4_733, %457, %int8_734, %int128_735 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1115 = torch.aten._unsafe_view %1113, %1114 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1115, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_736 = torch.constant.int -2 + %1116 = torch.aten.unsqueeze %1111, %int-2_736 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %1116, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_737 = torch.constant.int 4 + %int8_738 = torch.constant.int 8 + %int4_739 = torch.constant.int 4 + %int128_740 = torch.constant.int 128 + %1117 = torch.prim.ListConstruct %int4_737, %457, %int8_738, %int4_739, %int128_740 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_741 = torch.constant.bool false + %1118 = torch.aten.expand %1116, %1117, %false_741 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1118, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_742 = torch.constant.int 0 + %1119 = torch.aten.clone %1118, %int0_742 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1119, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_743 = torch.constant.int 4 + %int32_744 = torch.constant.int 32 + %int128_745 = torch.constant.int 128 + %1120 = torch.prim.ListConstruct %int4_743, %457, %int32_744, %int128_745 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1121 = torch.aten._unsafe_view %1119, %1120 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1121, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_746 = torch.constant.int -2 + %1122 = torch.aten.unsqueeze %1115, %int-2_746 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %1122, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : 
!torch.vtensor<[4,?,8,1,128],f16> + %int4_747 = torch.constant.int 4 + %int8_748 = torch.constant.int 8 %int4_749 = torch.constant.int 4 - %int8_750 = torch.constant.int 8 - %int128_751 = torch.constant.int 128 - %932 = torch.prim.ListConstruct %int4_749, %931, %int8_750, %int128_751 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %933 = torch.aten._unsafe_view %929, %932 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %933, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int128_750 = torch.constant.int 128 + %1123 = torch.prim.ListConstruct %int4_747, %457, %int8_748, %int4_749, %int128_750 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_751 = torch.constant.bool false + %1124 = torch.aten.expand %1122, %1123, %false_751 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1124, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int0_752 = torch.constant.int 0 - %int0_753 = torch.constant.int 0 - %int9223372036854775807_754 = torch.constant.int 9223372036854775807 - %int1_755 = torch.constant.int 1 - %934 = torch.aten.slice.Tensor %933, %int0_752, %int0_753, %int9223372036854775807_754, %int1_755 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %934, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_756 = torch.constant.int 0 - %int0_757 = torch.constant.int 0 - %int9223372036854775807_758 = torch.constant.int 9223372036854775807 - %int1_759 = torch.constant.int 1 - %935 = torch.aten.slice.Tensor %923, %int0_756, %int0_757, %int9223372036854775807_758, %int1_759 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %935, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> + %1125 = torch.aten.clone %1124, %int0_752 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1125, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_753 = torch.constant.int 4 + %int32_754 = torch.constant.int 32 + %int128_755 = torch.constant.int 128 + %1126 = torch.prim.ListConstruct %int4_753, %457, %int32_754, %int128_755 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1127 = torch.aten._unsafe_view %1125, %1126 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1127, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_756 = torch.constant.int 1 + %int2_757 = torch.constant.int 2 + %1128 = torch.aten.transpose.int %1010, %int1_756, %int2_757 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_758 = torch.constant.int 1 + %int2_759 = torch.constant.int 2 + %1129 = torch.aten.transpose.int %1121, %int1_758, %int2_759 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1129, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_760 = torch.constant.int 1 - %int0_761 = 
torch.constant.int 0 - %int9223372036854775807_762 = torch.constant.int 9223372036854775807 - %int1_763 = torch.constant.int 1 - %936 = torch.aten.slice.Tensor %935, %int1_760, %int0_761, %int9223372036854775807_762, %int1_763 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %936, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_764 = torch.constant.int 2 + %int2_761 = torch.constant.int 2 + %1130 = torch.aten.transpose.int %1127, %int1_760, %int2_761 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1130, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_762 = torch.constant.float 0.000000e+00 + %false_763 = torch.constant.bool false + %none_764 = torch.constant.none + %1131:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1128, %1129, %1130, %float0.000000e00_762, %false_763, %470, %none_764) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) %int1_765 = torch.constant.int 1 - %937 = torch.aten.select.int %936, %int2_764, %int1_765 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %937, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int2_766 = torch.constant.int 2 - %int0_767 = torch.constant.int 0 + %1132 = torch.aten.transpose.int %1131#0, %int1_765, %int2_766 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_767 = torch.constant.int 4 %int1_768 = torch.constant.int 1 - %938 = torch.aten.slice.Tensor %937, %int2_766, %int0_767, %927, %int1_768 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %938, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_769 = torch.constant.int 0 - %939 = torch.aten.clone %938, %int0_769 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %939, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_770 = torch.constant.int 1 - %940 = torch.aten.size.int %936, %int1_770 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_771 = torch.constant.int 32 - %941 = torch.aten.mul.int %940, %int32_771 : !torch.int, !torch.int -> !torch.int - %int4_772 = torch.constant.int 4 - %int8_773 = torch.constant.int 8 - %int128_774 = torch.constant.int 128 - %942 = torch.prim.ListConstruct %int4_772, %941, %int8_773, %int128_774 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %943 = torch.aten._unsafe_view %939, %942 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %943, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_775 = torch.constant.int 0 - %int0_776 = torch.constant.int 0 - %int9223372036854775807_777 = torch.constant.int 9223372036854775807 + %int4096_769 = torch.constant.int 4096 + %1133 = 
torch.prim.ListConstruct %int4_767, %int1_768, %int4096_769 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1134 = torch.aten.view %1132, %1133 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_770 = torch.constant.int -2 + %int-1_771 = torch.constant.int -1 + %1135 = torch.aten.transpose.int %39, %int-2_770, %int-1_771 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_772 = torch.constant.int 5 + %1136 = torch.prims.convert_element_type %1135, %int5_772 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_773 = torch.constant.int 4 + %int4096_774 = torch.constant.int 4096 + %1137 = torch.prim.ListConstruct %int4_773, %int4096_774 : (!torch.int, !torch.int) -> !torch.list + %1138 = torch.aten.view %1134, %1137 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1139 = torch.aten.mm %1138, %1136 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_775 = torch.constant.int 4 + %int1_776 = torch.constant.int 1 + %int4096_777 = torch.constant.int 4096 + %1140 = torch.prim.ListConstruct %int4_775, %int1_776, %int4096_777 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1141 = torch.aten.view %1139, %1140 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_778 = torch.constant.int 1 - %944 = torch.aten.slice.Tensor %943, %int0_775, %int0_776, %int9223372036854775807_777, %int1_778 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %944, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_779 = torch.constant.int -2 - %945 = torch.aten.unsqueeze %934, %int-2_779 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %945, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_780 = torch.constant.int 1 - %946 = torch.aten.size.int %933, %int1_780 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_781 = torch.constant.int 4 - %int8_782 = torch.constant.int 8 - %int4_783 = torch.constant.int 4 - %int128_784 = torch.constant.int 128 - %947 = torch.prim.ListConstruct %int4_781, %946, %int8_782, %int4_783, %int128_784 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_785 = torch.constant.bool false - %948 = torch.aten.expand %945, %947, %false_785 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %948, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_786 = torch.constant.int 0 - %949 = torch.aten.clone %948, %int0_786 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %949, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_787 = torch.constant.int 4 - %int32_788 = torch.constant.int 32 - %int128_789 = torch.constant.int 128 - %950 = torch.prim.ListConstruct %int4_787, %946, %int32_788, %int128_789 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %951 = torch.aten._unsafe_view %949, %950 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %951, [%356], 
affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_790 = torch.constant.int -2 - %952 = torch.aten.unsqueeze %944, %int-2_790 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %952, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_791 = torch.constant.int 1 - %953 = torch.aten.size.int %943, %int1_791 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_792 = torch.constant.int 4 - %int8_793 = torch.constant.int 8 - %int4_794 = torch.constant.int 4 - %int128_795 = torch.constant.int 128 - %954 = torch.prim.ListConstruct %int4_792, %953, %int8_793, %int4_794, %int128_795 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_796 = torch.constant.bool false - %955 = torch.aten.expand %952, %954, %false_796 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %955, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_797 = torch.constant.int 0 - %956 = torch.aten.clone %955, %int0_797 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %956, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_798 = torch.constant.int 4 - %int32_799 = torch.constant.int 32 - %int128_800 = torch.constant.int 128 - %957 = torch.prim.ListConstruct %int4_798, %953, %int32_799, %int128_800 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %958 = torch.aten._unsafe_view %956, %957 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %958, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_801 = torch.constant.int 1 - %int2_802 = torch.constant.int 2 - %959 = torch.aten.transpose.int %839, %int1_801, %int2_802 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_803 = torch.constant.int 1 - %int2_804 = torch.constant.int 2 - %960 = torch.aten.transpose.int %951, %int1_803, %int2_804 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %960, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_805 = torch.constant.int 1 - %int2_806 = torch.constant.int 2 - %961 = torch.aten.transpose.int %958, %int1_805, %int2_806 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %961, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_807 = torch.constant.float 0.000000e+00 - %false_808 = torch.constant.bool false - %none_809 = torch.constant.none - %962:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%959, %960, %961, %float0.000000e00_807, %false_808, %368, %none_809) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %1142 = torch.aten.add.Tensor %963, %1141, %int1_778 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_779 = 
torch.constant.int 6 + %1143 = torch.prims.convert_element_type %1142, %int6_779 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_780 = torch.constant.int 2 + %1144 = torch.aten.pow.Tensor_Scalar %1143, %int2_780 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_781 = torch.constant.int -1 + %1145 = torch.prim.ListConstruct %int-1_781 : (!torch.int) -> !torch.list + %true_782 = torch.constant.bool true + %none_783 = torch.constant.none + %1146 = torch.aten.mean.dim %1144, %1145, %true_782, %none_783 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_784 = torch.constant.float 9.9999997473787516E-6 + %int1_785 = torch.constant.int 1 + %1147 = torch.aten.add.Scalar %1146, %float9.999990e-06_784, %int1_785 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %1148 = torch.aten.rsqrt %1147 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %1149 = torch.aten.mul.Tensor %1143, %1148 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_786 = torch.constant.int 5 + %1150 = torch.prims.convert_element_type %1149, %int5_786 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %1151 = torch.aten.mul.Tensor %40, %1150 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_787 = torch.constant.int 5 + %1152 = torch.prims.convert_element_type %1151, %int5_787 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_788 = torch.constant.int -2 + %int-1_789 = torch.constant.int -1 + %1153 = torch.aten.transpose.int %41, %int-2_788, %int-1_789 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_790 = torch.constant.int 5 + %1154 = torch.prims.convert_element_type %1153, %int5_790 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_791 = torch.constant.int 4 + %int4096_792 = torch.constant.int 4096 + %1155 = torch.prim.ListConstruct %int4_791, %int4096_792 : (!torch.int, !torch.int) -> !torch.list + %1156 = torch.aten.view %1152, %1155 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1157 = torch.aten.mm %1156, %1154 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_793 = torch.constant.int 4 + %int1_794 = torch.constant.int 1 + %int14336_795 = torch.constant.int 14336 + %1158 = torch.prim.ListConstruct %int4_793, %int1_794, %int14336_795 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1159 = torch.aten.view %1157, %1158 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %1160 = torch.aten.silu %1159 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_796 = torch.constant.int -2 + %int-1_797 = torch.constant.int -1 + %1161 = torch.aten.transpose.int %42, %int-2_796, %int-1_797 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_798 = torch.constant.int 5 + %1162 = torch.prims.convert_element_type %1161, %int5_798 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_799 = torch.constant.int 4 + %int4096_800 = torch.constant.int 4096 + %1163 = torch.prim.ListConstruct %int4_799, %int4096_800 : (!torch.int, !torch.int) -> !torch.list 
+ %1164 = torch.aten.view %1152, %1163 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1165 = torch.aten.mm %1164, %1162 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_801 = torch.constant.int 4 + %int1_802 = torch.constant.int 1 + %int14336_803 = torch.constant.int 14336 + %1166 = torch.prim.ListConstruct %int4_801, %int1_802, %int14336_803 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1167 = torch.aten.view %1165, %1166 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %1168 = torch.aten.mul.Tensor %1160, %1167 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_804 = torch.constant.int -2 + %int-1_805 = torch.constant.int -1 + %1169 = torch.aten.transpose.int %43, %int-2_804, %int-1_805 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_806 = torch.constant.int 5 + %1170 = torch.prims.convert_element_type %1169, %int5_806 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_807 = torch.constant.int 4 + %int14336_808 = torch.constant.int 14336 + %1171 = torch.prim.ListConstruct %int4_807, %int14336_808 : (!torch.int, !torch.int) -> !torch.list + %1172 = torch.aten.view %1168, %1171 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %1173 = torch.aten.mm %1172, %1170 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_809 = torch.constant.int 4 %int1_810 = torch.constant.int 1 - %int2_811 = torch.constant.int 2 - %963 = torch.aten.transpose.int %962#0, %int1_810, %int2_811 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_812 = torch.constant.int 4 - %int1_813 = torch.constant.int 1 - %int4096_814 = torch.constant.int 4096 - %964 = torch.prim.ListConstruct %int4_812, %int1_813, %int4096_814 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %965 = torch.aten.view %963, %964 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_815 = torch.constant.int -2 - %int-1_816 = torch.constant.int -1 - %966 = torch.aten.transpose.int %29, %int-2_815, %int-1_816 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_817 = torch.constant.int 4 - %int4096_818 = torch.constant.int 4096 - %967 = torch.prim.ListConstruct %int4_817, %int4096_818 : (!torch.int, !torch.int) -> !torch.list - %968 = torch.aten.view %965, %967 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %969 = torch.aten.mm %968, %966 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_819 = torch.constant.int 4 - %int1_820 = torch.constant.int 1 - %int4096_821 = torch.constant.int 4096 - %970 = torch.prim.ListConstruct %int4_819, %int1_820, %int4096_821 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %971 = torch.aten.view %969, %970 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_822 = torch.constant.int 1 - %972 = torch.aten.add.Tensor %799, %971, %int1_822 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_823 = torch.constant.int 6 - %973 = torch.prims.convert_element_type %972, %int6_823 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> 
!torch.vtensor<[4,1,4096],f32> - %int2_824 = torch.constant.int 2 - %974 = torch.aten.pow.Tensor_Scalar %973, %int2_824 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_825 = torch.constant.int -1 - %975 = torch.prim.ListConstruct %int-1_825 : (!torch.int) -> !torch.list - %true_826 = torch.constant.bool true - %none_827 = torch.constant.none - %976 = torch.aten.mean.dim %974, %975, %true_826, %none_827 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_828 = torch.constant.float 9.9999997473787516E-6 - %int1_829 = torch.constant.int 1 - %977 = torch.aten.add.Scalar %976, %float9.999990e-06_828, %int1_829 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %978 = torch.aten.rsqrt %977 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %979 = torch.aten.mul.Tensor %973, %978 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_830 = torch.constant.int 5 - %980 = torch.prims.convert_element_type %979, %int5_830 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %981 = torch.aten.mul.Tensor %30, %980 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_831 = torch.constant.int 5 - %982 = torch.prims.convert_element_type %981, %int5_831 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_832 = torch.constant.int -2 - %int-1_833 = torch.constant.int -1 - %983 = torch.aten.transpose.int %31, %int-2_832, %int-1_833 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_834 = torch.constant.int 4 - %int4096_835 = torch.constant.int 4096 - %984 = torch.prim.ListConstruct %int4_834, %int4096_835 : (!torch.int, !torch.int) -> !torch.list - %985 = torch.aten.view %982, %984 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %986 = torch.aten.mm %985, %983 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_836 = torch.constant.int 4 - %int1_837 = torch.constant.int 1 - %int14336_838 = torch.constant.int 14336 - %987 = torch.prim.ListConstruct %int4_836, %int1_837, %int14336_838 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %988 = torch.aten.view %986, %987 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %989 = torch.aten.silu %988 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_839 = torch.constant.int -2 - %int-1_840 = torch.constant.int -1 - %990 = torch.aten.transpose.int %32, %int-2_839, %int-1_840 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4096_811 = torch.constant.int 4096 + %1174 = torch.prim.ListConstruct %int4_809, %int1_810, %int4096_811 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1175 = torch.aten.view %1173, %1174 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_812 = torch.constant.int 1 + %1176 = torch.aten.add.Tensor %1142, %1175, %int1_812 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_813 = torch.constant.int 6 + %1177 = torch.prims.convert_element_type %1176, %int6_813 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_814 = torch.constant.int 2 + 
%1178 = torch.aten.pow.Tensor_Scalar %1177, %int2_814 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_815 = torch.constant.int -1 + %1179 = torch.prim.ListConstruct %int-1_815 : (!torch.int) -> !torch.list + %true_816 = torch.constant.bool true + %none_817 = torch.constant.none + %1180 = torch.aten.mean.dim %1178, %1179, %true_816, %none_817 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_818 = torch.constant.float 9.9999997473787516E-6 + %int1_819 = torch.constant.int 1 + %1181 = torch.aten.add.Scalar %1180, %float9.999990e-06_818, %int1_819 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %1182 = torch.aten.rsqrt %1181 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %1183 = torch.aten.mul.Tensor %1177, %1182 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_820 = torch.constant.int 5 + %1184 = torch.prims.convert_element_type %1183, %int5_820 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %1185 = torch.aten.mul.Tensor %44, %1184 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_821 = torch.constant.int 5 + %1186 = torch.prims.convert_element_type %1185, %int5_821 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_822 = torch.constant.int -2 + %int-1_823 = torch.constant.int -1 + %1187 = torch.aten.transpose.int %45, %int-2_822, %int-1_823 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_824 = torch.constant.int 5 + %1188 = torch.prims.convert_element_type %1187, %int5_824 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_825 = torch.constant.int 4 + %int4096_826 = torch.constant.int 4096 + %1189 = torch.prim.ListConstruct %int4_825, %int4096_826 : (!torch.int, !torch.int) -> !torch.list + %1190 = torch.aten.view %1186, %1189 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1191 = torch.aten.mm %1190, %1188 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_827 = torch.constant.int 4 + %int1_828 = torch.constant.int 1 + %int4096_829 = torch.constant.int 4096 + %1192 = torch.prim.ListConstruct %int4_827, %int1_828, %int4096_829 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1193 = torch.aten.view %1191, %1192 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_830 = torch.constant.int -2 + %int-1_831 = torch.constant.int -1 + %1194 = torch.aten.transpose.int %46, %int-2_830, %int-1_831 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_832 = torch.constant.int 5 + %1195 = torch.prims.convert_element_type %1194, %int5_832 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_833 = torch.constant.int 4 + %int4096_834 = torch.constant.int 4096 + %1196 = torch.prim.ListConstruct %int4_833, %int4096_834 : (!torch.int, !torch.int) -> !torch.list + %1197 = torch.aten.view %1186, %1196 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1198 = torch.aten.mm %1197, %1195 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_835 = torch.constant.int 4 + %int1_836 = 
torch.constant.int 1 + %int1024_837 = torch.constant.int 1024 + %1199 = torch.prim.ListConstruct %int4_835, %int1_836, %int1024_837 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1200 = torch.aten.view %1198, %1199 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_838 = torch.constant.int -2 + %int-1_839 = torch.constant.int -1 + %1201 = torch.aten.transpose.int %47, %int-2_838, %int-1_839 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_840 = torch.constant.int 5 + %1202 = torch.prims.convert_element_type %1201, %int5_840 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> %int4_841 = torch.constant.int 4 %int4096_842 = torch.constant.int 4096 - %991 = torch.prim.ListConstruct %int4_841, %int4096_842 : (!torch.int, !torch.int) -> !torch.list - %992 = torch.aten.view %982, %991 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %993 = torch.aten.mm %992, %990 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %1203 = torch.prim.ListConstruct %int4_841, %int4096_842 : (!torch.int, !torch.int) -> !torch.list + %1204 = torch.aten.view %1186, %1203 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1205 = torch.aten.mm %1204, %1202 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> %int4_843 = torch.constant.int 4 %int1_844 = torch.constant.int 1 - %int14336_845 = torch.constant.int 14336 - %994 = torch.prim.ListConstruct %int4_843, %int1_844, %int14336_845 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %995 = torch.aten.view %993, %994 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %996 = torch.aten.mul.Tensor %989, %995 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_846 = torch.constant.int -2 - %int-1_847 = torch.constant.int -1 - %997 = torch.aten.transpose.int %33, %int-2_846, %int-1_847 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_848 = torch.constant.int 4 - %int14336_849 = torch.constant.int 14336 - %998 = torch.prim.ListConstruct %int4_848, %int14336_849 : (!torch.int, !torch.int) -> !torch.list - %999 = torch.aten.view %996, %998 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %1000 = torch.aten.mm %999, %997 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int1024_845 = torch.constant.int 1024 + %1206 = torch.prim.ListConstruct %int4_843, %int1_844, %int1024_845 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1207 = torch.aten.view %1205, %1206 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_846 = torch.constant.int 4 + %int1_847 = torch.constant.int 1 + %int32_848 = torch.constant.int 32 + %int128_849 = torch.constant.int 128 + %1208 = torch.prim.ListConstruct %int4_846, %int1_847, %int32_848, %int128_849 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1209 = torch.aten.view %1193, %1208 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> %int4_850 = torch.constant.int 4 %int1_851 = torch.constant.int 1 - %int4096_852 = torch.constant.int 4096 - %1001 = torch.prim.ListConstruct %int4_850, %int1_851, %int4096_852 : (!torch.int, !torch.int, !torch.int) -> 
!torch.list - %1002 = torch.aten.view %1000, %1001 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_853 = torch.constant.int 1 - %1003 = torch.aten.add.Tensor %972, %1002, %int1_853 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_854 = torch.constant.int 6 - %1004 = torch.prims.convert_element_type %1003, %int6_854 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_855 = torch.constant.int 2 - %1005 = torch.aten.pow.Tensor_Scalar %1004, %int2_855 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_856 = torch.constant.int -1 - %1006 = torch.prim.ListConstruct %int-1_856 : (!torch.int) -> !torch.list - %true_857 = torch.constant.bool true - %none_858 = torch.constant.none - %1007 = torch.aten.mean.dim %1005, %1006, %true_857, %none_858 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_859 = torch.constant.float 9.9999997473787516E-6 - %int1_860 = torch.constant.int 1 - %1008 = torch.aten.add.Scalar %1007, %float9.999990e-06_859, %int1_860 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %1009 = torch.aten.rsqrt %1008 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %1010 = torch.aten.mul.Tensor %1004, %1009 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_861 = torch.constant.int 5 - %1011 = torch.prims.convert_element_type %1010, %int5_861 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %1012 = torch.aten.mul.Tensor %34, %1011 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_862 = torch.constant.int 5 - %1013 = torch.prims.convert_element_type %1012, %int5_862 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_863 = torch.constant.int -2 - %int-1_864 = torch.constant.int -1 - %1014 = torch.aten.transpose.int %35, %int-2_863, %int-1_864 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_865 = torch.constant.int 4 - %int4096_866 = torch.constant.int 4096 - %1015 = torch.prim.ListConstruct %int4_865, %int4096_866 : (!torch.int, !torch.int) -> !torch.list - %1016 = torch.aten.view %1013, %1015 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1017 = torch.aten.mm %1016, %1014 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_867 = torch.constant.int 4 - %int1_868 = torch.constant.int 1 - %int4096_869 = torch.constant.int 4096 - %1018 = torch.prim.ListConstruct %int4_867, %int1_868, %int4096_869 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1019 = torch.aten.view %1017, %1018 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_870 = torch.constant.int -2 - %int-1_871 = torch.constant.int -1 - %1020 = torch.aten.transpose.int %36, %int-2_870, %int-1_871 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_872 = torch.constant.int 4 - %int4096_873 = torch.constant.int 4096 - %1021 = torch.prim.ListConstruct %int4_872, %int4096_873 : (!torch.int, !torch.int) -> !torch.list - %1022 = torch.aten.view %1013, %1021 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1023 
= torch.aten.mm %1022, %1020 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_874 = torch.constant.int 4 - %int1_875 = torch.constant.int 1 - %int1024_876 = torch.constant.int 1024 - %1024 = torch.prim.ListConstruct %int4_874, %int1_875, %int1024_876 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1025 = torch.aten.view %1023, %1024 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_877 = torch.constant.int -2 - %int-1_878 = torch.constant.int -1 - %1026 = torch.aten.transpose.int %37, %int-2_877, %int-1_878 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_879 = torch.constant.int 4 - %int4096_880 = torch.constant.int 4096 - %1027 = torch.prim.ListConstruct %int4_879, %int4096_880 : (!torch.int, !torch.int) -> !torch.list - %1028 = torch.aten.view %1013, %1027 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1029 = torch.aten.mm %1028, %1026 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_881 = torch.constant.int 4 - %int1_882 = torch.constant.int 1 - %int1024_883 = torch.constant.int 1024 - %1030 = torch.prim.ListConstruct %int4_881, %int1_882, %int1024_883 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1031 = torch.aten.view %1029, %1030 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_884 = torch.constant.int 4 - %int1_885 = torch.constant.int 1 + %int8_852 = torch.constant.int 8 + %int128_853 = torch.constant.int 128 + %1210 = torch.prim.ListConstruct %int4_850, %int1_851, %int8_852, %int128_853 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1211 = torch.aten.view %1200, %1210 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_854 = torch.constant.int 4 + %int1_855 = torch.constant.int 1 + %int8_856 = torch.constant.int 8 + %int128_857 = torch.constant.int 128 + %1212 = torch.prim.ListConstruct %int4_854, %int1_855, %int8_856, %int128_857 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1213 = torch.aten.view %1207, %1212 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_858 = torch.constant.int 1 + %int2_859 = torch.constant.int 2 + %1214 = torch.aten.transpose.int %1209, %int1_858, %int2_859 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %1215 = torch.aten.mul.Tensor %1214, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_860 = torch.constant.int 3 + %int0_861 = torch.constant.int 0 + %int64_862 = torch.constant.int 64 + %int1_863 = torch.constant.int 1 + %1216 = torch.aten.slice.Tensor %1214, %int3_860, %int0_861, %int64_862, %int1_863 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_864 = torch.constant.int 3 + %int64_865 = torch.constant.int 64 + %int9223372036854775807_866 = torch.constant.int 9223372036854775807 + %int1_867 = torch.constant.int 1 + %1217 = torch.aten.slice.Tensor %1214, %int3_864, %int64_865, %int9223372036854775807_866, %int1_867 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %1218 = torch.aten.neg %1217 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %1219 = 
torch.prim.ListConstruct %1218, %1216 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_868 = torch.constant.int -1 + %1220 = torch.aten.cat %1219, %int-1_868 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %1221 = torch.aten.mul.Tensor %1220, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_869 = torch.constant.int 1 + %1222 = torch.aten.add.Tensor %1215, %1221, %int1_869 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_870 = torch.constant.int 1 + %int2_871 = torch.constant.int 2 + %1223 = torch.aten.transpose.int %1222, %int1_870, %int2_871 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_872 = torch.constant.int 1 + %int2_873 = torch.constant.int 2 + %1224 = torch.aten.transpose.int %1211, %int1_872, %int2_873 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %1225 = torch.aten.mul.Tensor %1224, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_874 = torch.constant.int 3 + %int0_875 = torch.constant.int 0 + %int64_876 = torch.constant.int 64 + %int1_877 = torch.constant.int 1 + %1226 = torch.aten.slice.Tensor %1224, %int3_874, %int0_875, %int64_876, %int1_877 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_878 = torch.constant.int 3 + %int64_879 = torch.constant.int 64 + %int9223372036854775807_880 = torch.constant.int 9223372036854775807 + %int1_881 = torch.constant.int 1 + %1227 = torch.aten.slice.Tensor %1224, %int3_878, %int64_879, %int9223372036854775807_880, %int1_881 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %1228 = torch.aten.neg %1227 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %1229 = torch.prim.ListConstruct %1228, %1226 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_882 = torch.constant.int -1 + %1230 = torch.aten.cat %1229, %int-1_882 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %1231 = torch.aten.mul.Tensor %1230, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_883 = torch.constant.int 1 + %1232 = torch.aten.add.Tensor %1225, %1231, %int1_883 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_884 = torch.constant.int 1 + %int2_885 = torch.constant.int 2 + %1233 = torch.aten.transpose.int %1232, %int1_884, %int2_885 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> %int32_886 = torch.constant.int 32 - %int128_887 = torch.constant.int 128 - %1032 = torch.prim.ListConstruct %int4_884, %int1_885, %int32_886, %int128_887 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1033 = torch.aten.view %1019, %1032 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_888 = torch.constant.int 4 - %int1_889 = torch.constant.int 1 - %int8_890 = torch.constant.int 8 - %int128_891 = torch.constant.int 128 - %1034 = torch.prim.ListConstruct %int4_888, %int1_889, %int8_890, %int128_891 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1035 = 
torch.aten.view %1025, %1034 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_892 = torch.constant.int 4 - %int1_893 = torch.constant.int 1 - %int8_894 = torch.constant.int 8 - %int128_895 = torch.constant.int 128 - %1036 = torch.prim.ListConstruct %int4_892, %int1_893, %int8_894, %int128_895 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1037 = torch.aten.view %1031, %1036 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_896 = torch.constant.int 6 - %1038 = torch.prims.convert_element_type %1033, %int6_896 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %1039 = torch_c.to_builtin_tensor %1038 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %1040 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %1041 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%1039, %1040) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %1042 = torch_c.from_builtin_tensor %1041 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_897 = torch.constant.int 5 - %1043 = torch.prims.convert_element_type %1042, %int5_897 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_898 = torch.constant.int 6 - %1044 = torch.prims.convert_element_type %1035, %int6_898 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %1045 = torch_c.to_builtin_tensor %1044 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %1046 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %1047 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%1045, %1046) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %1048 = torch_c.from_builtin_tensor %1047 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_899 = torch.constant.int 5 - %1049 = torch.prims.convert_element_type %1048, %int5_899 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_900 = torch.constant.int 32 - %1050 = torch.aten.floor_divide.Scalar %arg2, %int32_900 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_901 = torch.constant.int 1 - %1051 = torch.aten.unsqueeze %1050, %int1_901 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %1234 = torch.aten.floor_divide.Scalar %arg2, %int32_886 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_887 = torch.constant.int 1 + %1235 = torch.aten.unsqueeze %1234, %int1_887 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_888 = torch.constant.int 1 + %false_889 = torch.constant.bool false + %1236 = torch.aten.gather %arg3, %int1_888, %1235, %false_889 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_890 = torch.constant.int 4 + %int1_891 = torch.constant.int 1 + %int1_892 = torch.constant.int 1 + %1237 = torch.prim.ListConstruct %int4_890, %int1_891, %int1_892 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1238 = torch.aten.view %1236, %1237 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_893 = torch.constant.int 32 + %1239 = torch.aten.remainder.Scalar %arg2, %int32_893 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_894 = torch.constant.int 4 + %int1_895 = torch.constant.int 1 + 
%int1_896 = torch.constant.int 1 + %1240 = torch.prim.ListConstruct %int4_894, %int1_895, %int1_896 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1241 = torch.aten.view %1239, %1240 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_897 = torch.constant.int 8 + %none_898 = torch.constant.none + %none_899 = torch.constant.none + %cpu_900 = torch.constant.device "cpu" + %false_901 = torch.constant.bool false + %1242 = torch.aten.arange %int8_897, %none_898, %none_899, %cpu_900, %false_901 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> %int1_902 = torch.constant.int 1 - %false_903 = torch.constant.bool false - %1052 = torch.aten.gather %arg3, %int1_902, %1051, %false_903 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_904 = torch.constant.int 32 - %1053 = torch.aten.remainder.Scalar %arg2, %int32_904 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_905 = torch.constant.int 1 - %1054 = torch.aten.unsqueeze %1053, %int1_905 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_906 = torch.constant.none - %1055 = torch.aten.clone %38, %none_906 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_907 = torch.constant.int 0 - %1056 = torch.aten.unsqueeze %1055, %int0_907 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_908 = torch.constant.int 4 - %int1_909 = torch.constant.int 1 - %1057 = torch.prim.ListConstruct %int4_908, %int1_909 : (!torch.int, !torch.int) -> !torch.list - %int1_910 = torch.constant.int 1 + %int1_903 = torch.constant.int 1 + %int8_904 = torch.constant.int 8 + %1243 = torch.prim.ListConstruct %int1_902, %int1_903, %int8_904 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1244 = torch.aten.view %1242, %1243 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_905 = torch.constant.none + %1245 = torch.aten.clone %48, %none_905 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1246 = torch.aten.detach %1245 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1247 = torch.aten.detach %1246 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1248 = torch.aten.detach %1247 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_906 = torch.constant.int 1 + %int1_907 = torch.constant.int 1 + %int1_908 = torch.constant.int 1 + %1249 = torch.prim.ListConstruct %int1_906, %int1_907, %int1_908 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1250 = torch.aten.view %1248, %1249 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_909 = torch.constant.int 32 + %1251 = torch.aten.mul.Scalar %1238, %int32_909 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int3_910 = torch.constant.int 3 %int1_911 = torch.constant.int 1 - %1058 = torch.prim.ListConstruct %int1_910, %int1_911 : (!torch.int, !torch.int) -> !torch.list - %int4_912 = torch.constant.int 4 - %int0_913 = torch.constant.int 0 - %cpu_914 = torch.constant.device "cpu" - %false_915 = torch.constant.bool false - %1059 = torch.aten.empty_strided %1057, %1058, %int4_912, %int0_913, %cpu_914, %false_915 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int3 = torch.constant.int 3 - %1060 = torch.aten.fill.Scalar %1059, %int3 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - 
%int4_916 = torch.constant.int 4 + %1252 = torch.aten.add.Scalar %1251, %int3_910, %int1_911 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_912 = torch.constant.int 2 + %1253 = torch.aten.mul.Scalar %1252, %int2_912 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_913 = torch.constant.int 1 + %1254 = torch.aten.add.Tensor %1253, %1250, %int1_913 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_914 = torch.constant.int 8 + %1255 = torch.aten.mul.Scalar %1254, %int8_914 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_915 = torch.constant.int 1 + %1256 = torch.aten.add.Tensor %1255, %1244, %int1_915 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_916 = torch.constant.int 32 + %1257 = torch.aten.mul.Scalar %1256, %int32_916 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_917 = torch.constant.int 1 - %1061 = torch.prim.ListConstruct %int4_916, %int1_917 : (!torch.int, !torch.int) -> !torch.list - %1062 = torch.aten.repeat %1056, %1061 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_918 = torch.constant.int 32 - %1063 = torch.aten.mul.Scalar %1052, %int32_918 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_919 = torch.constant.int 1 - %1064 = torch.aten.add.Tensor %1063, %1060, %int1_919 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %1258 = torch.aten.add.Tensor %1257, %1241, %int1_917 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_918 = torch.constant.int 5 + %1259 = torch.prims.convert_element_type %1233, %int5_918 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_919 = torch.constant.int 32 %int2_920 = torch.constant.int 2 - %1065 = torch.aten.mul.Scalar %1064, %int2_920 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_921 = torch.constant.int 1 - %1066 = torch.aten.add.Tensor %1065, %1062, %int1_921 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int8_921 = torch.constant.int 8 %int32_922 = torch.constant.int 32 - %1067 = torch.aten.mul.Scalar %1066, %int32_922 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_923 = torch.constant.int 1 - %1068 = torch.aten.add.Tensor %1067, %1054, %int1_923 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_924 = torch.constant.int 32 - %int2_925 = torch.constant.int 2 + %int128_923 = torch.constant.int 128 + %1260 = torch.prim.ListConstruct %456, %int32_919, %int2_920, %int8_921, %int32_922, %int128_923 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1261 = torch.aten.view %1081, %1260 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1261, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_924 = torch.constant.int 128 + %1262 = torch.prim.ListConstruct %596, %int128_924 : (!torch.int, !torch.int) -> !torch.list + %1263 = torch.aten.view %1261, %1262 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + 
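// Note (inferred): the index_put that follows scatters the new [4,1,8,128] slab into the
// cache flattened as [pages * 16384, 128] (16384 = 32 * 2 * 8 * 32), and the paired
// views then round-trip the result back to the opaque [?,2097152] page-buffer layout.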
torch.bind_symbolic_shape %1263, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %1264 = torch.prim.ListConstruct %1258 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_925 = torch.constant.bool false + %1265 = torch.aten.index_put %1263, %1264, %1259, %false_925 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1265, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> %int32_926 = torch.constant.int 32 - %int8_927 = torch.constant.int 8 - %int128_928 = torch.constant.int 128 - %1069 = torch.prim.ListConstruct %437, %int32_924, %int2_925, %int32_926, %int8_927, %int128_928 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1070 = torch.aten.view %906, %1069 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1070, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> + %int2_927 = torch.constant.int 2 + %int8_928 = torch.constant.int 8 %int32_929 = torch.constant.int 32 - %1071 = torch.aten.mul.int %437, %int32_929 : !torch.int, !torch.int -> !torch.int - %int2_930 = torch.constant.int 2 - %1072 = torch.aten.mul.int %1071, %int2_930 : !torch.int, !torch.int -> !torch.int - %int32_931 = torch.constant.int 32 - %1073 = torch.aten.mul.int %1072, %int32_931 : !torch.int, !torch.int -> !torch.int - %int8_932 = torch.constant.int 8 - %int128_933 = torch.constant.int 128 - %1074 = torch.prim.ListConstruct %1073, %int8_932, %int128_933 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1075 = torch.aten.view %1070, %1074 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1075, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %1076 = torch.prim.ListConstruct %1068 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_934 = torch.constant.bool false - %1077 = torch.aten.index_put %1075, %1076, %1049, %false_934 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1077, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> + %int128_930 = torch.constant.int 128 + %1266 = torch.prim.ListConstruct %456, %int32_926, %int2_927, %int8_928, %int32_929, %int128_930 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1267 = torch.aten.view %1265, %1266 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1267, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_931 = torch.constant.int 2097152 + %1268 = torch.prim.ListConstruct %456, %int2097152_931 : (!torch.int, !torch.int) -> !torch.list + %1269 = torch.aten.view %1267, %1268 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %1269, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_932 = torch.constant.int 32 + %int2_933 = torch.constant.int 2 + %int8_934 = torch.constant.int 8 %int32_935 = torch.constant.int 32 - %int2_936 = torch.constant.int 2 - %int32_937 = torch.constant.int 32 - %int8_938 = torch.constant.int 8 - %int128_939 = torch.constant.int 128 - 
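// Note (inferred): the same build-index / index_put sequence now repeats with the scalar
// cloned from %49 instead of %48, i.e. the other K/V partition of this layer's cache
// slot, storing %1213 where the first pass stored %1233; the address arithmetic is
// otherwise identical.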
%1078 = torch.prim.ListConstruct %437, %int32_935, %int2_936, %int32_937, %int8_938, %int128_939 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1079 = torch.aten.view %1077, %1078 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1079, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_940 = torch.constant.int 2097152 - %1080 = torch.prim.ListConstruct %437, %int2097152_940 : (!torch.int, !torch.int) -> !torch.list - %1081 = torch.aten.view %1079, %1080 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1081, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_941 = torch.constant.int 32 - %int2_942 = torch.constant.int 2 - %int32_943 = torch.constant.int 32 - %int8_944 = torch.constant.int 8 - %int128_945 = torch.constant.int 128 - %1082 = torch.prim.ListConstruct %437, %int32_941, %int2_942, %int32_943, %int8_944, %int128_945 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1083 = torch.aten.view %1081, %1082 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1083, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_946 = torch.constant.int 8 - %int128_947 = torch.constant.int 128 - %1084 = torch.prim.ListConstruct %1073, %int8_946, %int128_947 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1085 = torch.aten.view %1083, %1084 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1085, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_948 = torch.constant.int 32 - %1086 = torch.aten.floor_divide.Scalar %arg2, %int32_948 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_949 = torch.constant.int 1 - %1087 = torch.aten.unsqueeze %1086, %int1_949 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int128_936 = torch.constant.int 128 + %1270 = torch.prim.ListConstruct %456, %int32_932, %int2_933, %int8_934, %int32_935, %int128_936 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1271 = torch.aten.view %1269, %1270 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1271, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_937 = torch.constant.int 128 + %1272 = torch.prim.ListConstruct %596, %int128_937 : (!torch.int, !torch.int) -> !torch.list + %1273 = torch.aten.view %1271, %1272 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1273, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_938 = torch.constant.none + %1274 = torch.aten.clone %49, %none_938 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1275 = torch.aten.detach %1274 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1276 = torch.aten.detach %1275 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1277 = torch.aten.detach %1276 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_939 = torch.constant.int 1 + %int1_940 = torch.constant.int 1 + %int1_941 = 
torch.constant.int 1 + %1278 = torch.prim.ListConstruct %int1_939, %int1_940, %int1_941 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1279 = torch.aten.view %1277, %1278 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_942 = torch.constant.int 32 + %1280 = torch.aten.mul.Scalar %1238, %int32_942 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int3_943 = torch.constant.int 3 + %int1_944 = torch.constant.int 1 + %1281 = torch.aten.add.Scalar %1280, %int3_943, %int1_944 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_945 = torch.constant.int 2 + %1282 = torch.aten.mul.Scalar %1281, %int2_945 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_946 = torch.constant.int 1 + %1283 = torch.aten.add.Tensor %1282, %1279, %int1_946 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_947 = torch.constant.int 8 + %1284 = torch.aten.mul.Scalar %1283, %int8_947 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_948 = torch.constant.int 1 + %1285 = torch.aten.add.Tensor %1284, %1244, %int1_948 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_949 = torch.constant.int 32 + %1286 = torch.aten.mul.Scalar %1285, %int32_949 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_950 = torch.constant.int 1 - %false_951 = torch.constant.bool false - %1088 = torch.aten.gather %arg3, %int1_950, %1087, %false_951 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_952 = torch.constant.int 32 - %1089 = torch.aten.remainder.Scalar %arg2, %int32_952 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_953 = torch.constant.int 1 - %1090 = torch.aten.unsqueeze %1089, %int1_953 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_954 = torch.constant.none - %1091 = torch.aten.clone %39, %none_954 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_955 = torch.constant.int 0 - %1092 = torch.aten.unsqueeze %1091, %int0_955 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_956 = torch.constant.int 4 - %int1_957 = torch.constant.int 1 - %1093 = torch.prim.ListConstruct %int4_956, %int1_957 : (!torch.int, !torch.int) -> !torch.list - %int1_958 = torch.constant.int 1 - %int1_959 = torch.constant.int 1 - %1094 = torch.prim.ListConstruct %int1_958, %int1_959 : (!torch.int, !torch.int) -> !torch.list - %int4_960 = torch.constant.int 4 - %int0_961 = torch.constant.int 0 - %cpu_962 = torch.constant.device "cpu" - %false_963 = torch.constant.bool false - %1095 = torch.aten.empty_strided %1093, %1094, %int4_960, %int0_961, %cpu_962, %false_963 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int3_964 = torch.constant.int 3 - %1096 = torch.aten.fill.Scalar %1095, %int3_964 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_965 = torch.constant.int 4 - %int1_966 = torch.constant.int 1 - %1097 = torch.prim.ListConstruct %int4_965, %int1_966 : (!torch.int, !torch.int) -> !torch.list - %1098 = torch.aten.repeat %1092, %1097 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_967 = torch.constant.int 32 - %1099 = torch.aten.mul.Scalar 
%1088, %int32_967 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_968 = torch.constant.int 1 - %1100 = torch.aten.add.Tensor %1099, %1096, %int1_968 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_969 = torch.constant.int 2 - %1101 = torch.aten.mul.Scalar %1100, %int2_969 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_970 = torch.constant.int 1 - %1102 = torch.aten.add.Tensor %1101, %1098, %int1_970 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_971 = torch.constant.int 32 - %1103 = torch.aten.mul.Scalar %1102, %int32_971 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_972 = torch.constant.int 1 - %1104 = torch.aten.add.Tensor %1103, %1090, %int1_972 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %1105 = torch.prim.ListConstruct %1104 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_973 = torch.constant.bool false - %1106 = torch.aten.index_put %1085, %1105, %1037, %false_973 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1106, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_974 = torch.constant.int 32 - %int2_975 = torch.constant.int 2 - %int32_976 = torch.constant.int 32 - %int8_977 = torch.constant.int 8 - %int128_978 = torch.constant.int 128 - %1107 = torch.prim.ListConstruct %437, %int32_974, %int2_975, %int32_976, %int8_977, %int128_978 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1108 = torch.aten.view %1106, %1107 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1108, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_979 = torch.constant.int 2097152 - %1109 = torch.prim.ListConstruct %437, %int2097152_979 : (!torch.int, !torch.int) -> !torch.list - %1110 = torch.aten.view %1108, %1109 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1110, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %1287 = torch.aten.add.Tensor %1286, %1241, %int1_950 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_951 = torch.constant.int 5 + %1288 = torch.prims.convert_element_type %1213, %int5_951 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %1289 = torch.prim.ListConstruct %1287 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_952 = torch.constant.bool false + %1290 = torch.aten.index_put %1273, %1289, %1288, %false_952 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1290, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_953 = torch.constant.int 32 + %int2_954 = torch.constant.int 2 + %int8_955 = torch.constant.int 8 + %int32_956 = torch.constant.int 32 + %int128_957 = torch.constant.int 128 + %1291 = torch.prim.ListConstruct %456, %int32_953, %int2_954, %int8_955, %int32_956, %int128_957 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list 
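// Note (inferred): reads now go through a custom gather. The cache is re-viewed as
// [?,32,2,8,32,128] and @paged_attention_kv_cache_gather_... is invoked twice against
// the page table %arg3, identically except for the partition scalar (%1302 vs. %1306),
// yielding the per-sequence [4,?,8,32,128] key and value tensors.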
+ %1292 = torch.aten.view %1290, %1291 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1292, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_958 = torch.constant.int 2097152 + %1293 = torch.prim.ListConstruct %456, %int2097152_958 : (!torch.int, !torch.int) -> !torch.list + %1294 = torch.aten.view %1292, %1293 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %1294, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_959 = torch.constant.none + %1295 = torch.aten.clone %50, %none_959 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1296 = torch.aten.detach %1295 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1297 = torch.aten.detach %1296 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1298 = torch.aten.detach %1297 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_960 = torch.constant.none + %1299 = torch.aten.clone %51, %none_960 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1300 = torch.aten.detach %1299 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1301 = torch.aten.detach %1300 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1302 = torch.aten.detach %1301 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_961 = torch.constant.none + %1303 = torch.aten.clone %52, %none_961 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1304 = torch.aten.detach %1303 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1305 = torch.aten.detach %1304 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1306 = torch.aten.detach %1305 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_962 = torch.constant.int 32 + %int2_963 = torch.constant.int 2 + %int8_964 = torch.constant.int 8 + %int32_965 = torch.constant.int 32 + %int128_966 = torch.constant.int 128 + %1307 = torch.prim.ListConstruct %456, %int32_962, %int2_963, %int8_964, %int32_965, %int128_966 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1308 = torch.aten.view %1294, %1307 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1308, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %1309 = torch_c.to_builtin_tensor %1308 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %1310 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_967 = tensor.cast %1310 : tensor<4x?xi64> to tensor + %1311 = torch_c.to_builtin_tensor %1298 : !torch.vtensor<[],si64> -> tensor + %1312 = torch_c.to_builtin_tensor %1302 : !torch.vtensor<[],si64> -> tensor + %1313 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%1309, %cast_967, %1311, %1312) : (tensor, tensor, tensor, tensor) -> tensor + %cast_968 = tensor.cast %1313 : tensor to tensor<4x?x8x32x128xf16> + %1314 = torch_c.from_builtin_tensor %cast_968 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %1314, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %1315 = torch_c.to_builtin_tensor %1308 : 
!torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %1316 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_969 = tensor.cast %1316 : tensor<4x?xi64> to tensor + %1317 = torch_c.to_builtin_tensor %1298 : !torch.vtensor<[],si64> -> tensor + %1318 = torch_c.to_builtin_tensor %1306 : !torch.vtensor<[],si64> -> tensor + %1319 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%1315, %cast_969, %1317, %1318) : (tensor, tensor, tensor, tensor) -> tensor + %cast_970 = tensor.cast %1319 : tensor to tensor<4x?x8x32x128xf16> + %1320 = torch_c.from_builtin_tensor %cast_970 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %1320, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_971 = torch.constant.int 2 + %int3_972 = torch.constant.int 3 + %1321 = torch.aten.transpose.int %1314, %int2_971, %int3_972 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1321, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_973 = torch.constant.int 0 + %1322 = torch.aten.clone %1321, %int0_973 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1322, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_974 = torch.constant.int 4 + %int8_975 = torch.constant.int 8 + %int128_976 = torch.constant.int 128 + %1323 = torch.prim.ListConstruct %int4_974, %457, %int8_975, %int128_976 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1324 = torch.aten._unsafe_view %1322, %1323 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1324, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_977 = torch.constant.int 2 + %int3_978 = torch.constant.int 3 + %1325 = torch.aten.transpose.int %1320, %int2_977, %int3_978 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1325, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_979 = torch.constant.int 0 + %1326 = torch.aten.clone %1325, %int0_979 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1326, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int4_980 = torch.constant.int 4 - %1111 = torch.prim.ListConstruct %int4_980, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_981 = torch.constant.int 1 - %1112 = torch.prim.ListConstruct %358, %int1_981 : (!torch.int, !torch.int) -> !torch.list - %int4_982 = torch.constant.int 4 - %int0_983 = torch.constant.int 0 - %cpu_984 = torch.constant.device "cpu" - %false_985 = torch.constant.bool false - %1113 = torch.aten.empty_strided %1111, %1112, %int4_982, %int0_983, %cpu_984, %false_985 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1113, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int3_986 = torch.constant.int 3 - %1114 = torch.aten.fill.Scalar %1113, %int3_986 : 
!torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1114, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_987 = torch.constant.int 32 - %1115 = torch.aten.mul.Scalar %arg3, %int32_987 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1115, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_988 = torch.constant.int 1 - %1116 = torch.aten.add.Tensor %1115, %1114, %int1_988 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1116, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_989 = torch.constant.int 4 - %1117 = torch.aten.mul.int %int4_989, %358 : !torch.int, !torch.int -> !torch.int - %1118 = torch.prim.ListConstruct %1117 : (!torch.int) -> !torch.list - %1119 = torch.aten.view %1116, %1118 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %1119, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_990 = torch.constant.int 32 - %int2_991 = torch.constant.int 2 - %int32_992 = torch.constant.int 32 - %int8_993 = torch.constant.int 8 - %int128_994 = torch.constant.int 128 - %1120 = torch.prim.ListConstruct %437, %int32_990, %int2_991, %int32_992, %int8_993, %int128_994 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1121 = torch.aten.view %1110, %1120 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1121, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_995 = torch.constant.int 32 - %1122 = torch.aten.mul.int %437, %int32_995 : !torch.int, !torch.int -> !torch.int - %int2_996 = torch.constant.int 2 - %int32_997 = torch.constant.int 32 - %int8_998 = torch.constant.int 8 - %int128_999 = torch.constant.int 128 - %1123 = torch.prim.ListConstruct %1122, %int2_996, %int32_997, %int8_998, %int128_999 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1124 = torch.aten.view %1121, %1123 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %1124, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_1000 = torch.constant.int 0 - %1125 = torch.aten.index_select %1124, %int0_1000, %1119 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %1125, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_1001 = torch.constant.int 4 - %int2_1002 = torch.constant.int 2 - %int32_1003 = torch.constant.int 32 - %int8_1004 = torch.constant.int 8 - %int128_1005 = torch.constant.int 128 - %1126 = torch.prim.ListConstruct %int4_1001, %358, %int2_1002, %int32_1003, %int8_1004, %int128_1005 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1127 = torch.aten.view %1125, %1126 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1127, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_1006 = torch.constant.int 0 - %int0_1007 = torch.constant.int 0 - %int9223372036854775807_1008 = torch.constant.int 9223372036854775807 
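// Note: the surrounding `-` lines are the read path being replaced, which materialized
// block offsets with empty_strided / fill / mul / add and used index_select plus
// slice / select to split K and V out of a [4,?,2,32,8,128] view. The surviving `+`
// lines then expand the gathered 8 KV heads to the 32 query heads for grouped-query
// attention (unsqueeze, expand to [4,?,8,4,128], view as [4,?,32,128], i.e. 4 query
// heads per KV head) before calling _scaled_dot_product_flash_attention_for_cpu.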
- %int1_1009 = torch.constant.int 1 - %1128 = torch.aten.slice.Tensor %1127, %int0_1006, %int0_1007, %int9223372036854775807_1008, %int1_1009 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1128, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_1010 = torch.constant.int 1 - %int0_1011 = torch.constant.int 0 - %int9223372036854775807_1012 = torch.constant.int 9223372036854775807 - %int1_1013 = torch.constant.int 1 - %1129 = torch.aten.slice.Tensor %1128, %int1_1010, %int0_1011, %int9223372036854775807_1012, %int1_1013 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1129, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_1014 = torch.constant.int 2 - %int0_1015 = torch.constant.int 0 - %1130 = torch.aten.select.int %1129, %int2_1014, %int0_1015 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1130, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_1016 = torch.constant.int 32 - %1131 = torch.aten.mul.int %358, %int32_1016 : !torch.int, !torch.int -> !torch.int - %int2_1017 = torch.constant.int 2 - %int0_1018 = torch.constant.int 0 - %int1_1019 = torch.constant.int 1 - %1132 = torch.aten.slice.Tensor %1130, %int2_1017, %int0_1018, %1131, %int1_1019 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1132, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_1020 = torch.constant.int 0 - %1133 = torch.aten.clone %1132, %int0_1020 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1133, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_1021 = torch.constant.int 1 - %1134 = torch.aten.size.int %1129, %int1_1021 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_1022 = torch.constant.int 32 - %1135 = torch.aten.mul.int %1134, %int32_1022 : !torch.int, !torch.int -> !torch.int - %int4_1023 = torch.constant.int 4 - %int8_1024 = torch.constant.int 8 - %int128_1025 = torch.constant.int 128 - %1136 = torch.prim.ListConstruct %int4_1023, %1135, %int8_1024, %int128_1025 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1137 = torch.aten._unsafe_view %1133, %1136 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1137, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_1026 = torch.constant.int 0 - %int0_1027 = torch.constant.int 0 - %int9223372036854775807_1028 = torch.constant.int 9223372036854775807 - %int1_1029 = torch.constant.int 1 - %1138 = torch.aten.slice.Tensor %1137, %int0_1026, %int0_1027, %int9223372036854775807_1028, %int1_1029 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1138, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_1030 = torch.constant.int 0 - %int0_1031 = torch.constant.int 0 - %int9223372036854775807_1032 = 
torch.constant.int 9223372036854775807 - %int1_1033 = torch.constant.int 1 - %1139 = torch.aten.slice.Tensor %1127, %int0_1030, %int0_1031, %int9223372036854775807_1032, %int1_1033 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1139, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_1034 = torch.constant.int 1 - %int0_1035 = torch.constant.int 0 - %int9223372036854775807_1036 = torch.constant.int 9223372036854775807 - %int1_1037 = torch.constant.int 1 - %1140 = torch.aten.slice.Tensor %1139, %int1_1034, %int0_1035, %int9223372036854775807_1036, %int1_1037 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1140, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_1038 = torch.constant.int 2 - %int1_1039 = torch.constant.int 1 - %1141 = torch.aten.select.int %1140, %int2_1038, %int1_1039 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1141, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_1040 = torch.constant.int 2 - %int0_1041 = torch.constant.int 0 - %int1_1042 = torch.constant.int 1 - %1142 = torch.aten.slice.Tensor %1141, %int2_1040, %int0_1041, %1131, %int1_1042 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1142, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_1043 = torch.constant.int 0 - %1143 = torch.aten.clone %1142, %int0_1043 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1143, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_1044 = torch.constant.int 1 - %1144 = torch.aten.size.int %1140, %int1_1044 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_1045 = torch.constant.int 32 - %1145 = torch.aten.mul.int %1144, %int32_1045 : !torch.int, !torch.int -> !torch.int + %int8_981 = torch.constant.int 8 + %int128_982 = torch.constant.int 128 + %1327 = torch.prim.ListConstruct %int4_980, %457, %int8_981, %int128_982 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1328 = torch.aten._unsafe_view %1326, %1327 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1328, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_983 = torch.constant.int -2 + %1329 = torch.aten.unsqueeze %1324, %int-2_983 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %1329, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_984 = torch.constant.int 4 + %int8_985 = torch.constant.int 8 + %int4_986 = torch.constant.int 4 + %int128_987 = torch.constant.int 128 + %1330 = torch.prim.ListConstruct %int4_984, %457, %int8_985, %int4_986, %int128_987 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_988 = torch.constant.bool false + %1331 = torch.aten.expand %1329, %1330, %false_988 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, 
!torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1331, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_989 = torch.constant.int 0 + %1332 = torch.aten.clone %1331, %int0_989 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1332, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_990 = torch.constant.int 4 + %int32_991 = torch.constant.int 32 + %int128_992 = torch.constant.int 128 + %1333 = torch.prim.ListConstruct %int4_990, %457, %int32_991, %int128_992 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1334 = torch.aten._unsafe_view %1332, %1333 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1334, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_993 = torch.constant.int -2 + %1335 = torch.aten.unsqueeze %1328, %int-2_993 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %1335, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_994 = torch.constant.int 4 + %int8_995 = torch.constant.int 8 + %int4_996 = torch.constant.int 4 + %int128_997 = torch.constant.int 128 + %1336 = torch.prim.ListConstruct %int4_994, %457, %int8_995, %int4_996, %int128_997 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_998 = torch.constant.bool false + %1337 = torch.aten.expand %1335, %1336, %false_998 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1337, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_999 = torch.constant.int 0 + %1338 = torch.aten.clone %1337, %int0_999 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1338, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_1000 = torch.constant.int 4 + %int32_1001 = torch.constant.int 32 + %int128_1002 = torch.constant.int 128 + %1339 = torch.prim.ListConstruct %int4_1000, %457, %int32_1001, %int128_1002 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1340 = torch.aten._unsafe_view %1338, %1339 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1340, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_1003 = torch.constant.int 1 + %int2_1004 = torch.constant.int 2 + %1341 = torch.aten.transpose.int %1223, %int1_1003, %int2_1004 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_1005 = torch.constant.int 1 + %int2_1006 = torch.constant.int 2 + %1342 = torch.aten.transpose.int %1334, %int1_1005, %int2_1006 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1342, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_1007 = torch.constant.int 1 + %int2_1008 = torch.constant.int 2 + %1343 = torch.aten.transpose.int %1340, %int1_1007, %int2_1008 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1343, [%453], 
affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_1009 = torch.constant.float 0.000000e+00 + %false_1010 = torch.constant.bool false + %none_1011 = torch.constant.none + %1344:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1341, %1342, %1343, %float0.000000e00_1009, %false_1010, %470, %none_1011) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_1012 = torch.constant.int 1 + %int2_1013 = torch.constant.int 2 + %1345 = torch.aten.transpose.int %1344#0, %int1_1012, %int2_1013 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_1014 = torch.constant.int 4 + %int1_1015 = torch.constant.int 1 + %int4096_1016 = torch.constant.int 4096 + %1346 = torch.prim.ListConstruct %int4_1014, %int1_1015, %int4096_1016 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1347 = torch.aten.view %1345, %1346 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_1017 = torch.constant.int -2 + %int-1_1018 = torch.constant.int -1 + %1348 = torch.aten.transpose.int %53, %int-2_1017, %int-1_1018 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_1019 = torch.constant.int 5 + %1349 = torch.prims.convert_element_type %1348, %int5_1019 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_1020 = torch.constant.int 4 + %int4096_1021 = torch.constant.int 4096 + %1350 = torch.prim.ListConstruct %int4_1020, %int4096_1021 : (!torch.int, !torch.int) -> !torch.list + %1351 = torch.aten.view %1347, %1350 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1352 = torch.aten.mm %1351, %1349 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_1022 = torch.constant.int 4 + %int1_1023 = torch.constant.int 1 + %int4096_1024 = torch.constant.int 4096 + %1353 = torch.prim.ListConstruct %int4_1022, %int1_1023, %int4096_1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1354 = torch.aten.view %1352, %1353 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_1025 = torch.constant.int 1 + %1355 = torch.aten.add.Tensor %1176, %1354, %int1_1025 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_1026 = torch.constant.int 6 + %1356 = torch.prims.convert_element_type %1355, %int6_1026 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_1027 = torch.constant.int 2 + %1357 = torch.aten.pow.Tensor_Scalar %1356, %int2_1027 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_1028 = torch.constant.int -1 + %1358 = torch.prim.ListConstruct %int-1_1028 : (!torch.int) -> !torch.list + %true_1029 = torch.constant.bool true + %none_1030 = torch.constant.none + %1359 = torch.aten.mean.dim %1357, %1358, %true_1029, %none_1030 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_1031 = torch.constant.float 9.9999997473787516E-6 + %int1_1032 = torch.constant.int 1 + %1360 = torch.aten.add.Scalar %1359, %float9.999990e-06_1031, %int1_1032 : !torch.vtensor<[4,1,1],f32>, 
!torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %1361 = torch.aten.rsqrt %1360 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %1362 = torch.aten.mul.Tensor %1356, %1361 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_1033 = torch.constant.int 5 + %1363 = torch.prims.convert_element_type %1362, %int5_1033 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %1364 = torch.aten.mul.Tensor %54, %1363 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_1034 = torch.constant.int 5 + %1365 = torch.prims.convert_element_type %1364, %int5_1034 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_1035 = torch.constant.int -2 + %int-1_1036 = torch.constant.int -1 + %1366 = torch.aten.transpose.int %55, %int-2_1035, %int-1_1036 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_1037 = torch.constant.int 5 + %1367 = torch.prims.convert_element_type %1366, %int5_1037 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_1038 = torch.constant.int 4 + %int4096_1039 = torch.constant.int 4096 + %1368 = torch.prim.ListConstruct %int4_1038, %int4096_1039 : (!torch.int, !torch.int) -> !torch.list + %1369 = torch.aten.view %1365, %1368 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1370 = torch.aten.mm %1369, %1367 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_1040 = torch.constant.int 4 + %int1_1041 = torch.constant.int 1 + %int14336_1042 = torch.constant.int 14336 + %1371 = torch.prim.ListConstruct %int4_1040, %int1_1041, %int14336_1042 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1372 = torch.aten.view %1370, %1371 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %1373 = torch.aten.silu %1372 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_1043 = torch.constant.int -2 + %int-1_1044 = torch.constant.int -1 + %1374 = torch.aten.transpose.int %56, %int-2_1043, %int-1_1044 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_1045 = torch.constant.int 5 + %1375 = torch.prims.convert_element_type %1374, %int5_1045 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> %int4_1046 = torch.constant.int 4 - %int8_1047 = torch.constant.int 8 - %int128_1048 = torch.constant.int 128 - %1146 = torch.prim.ListConstruct %int4_1046, %1145, %int8_1047, %int128_1048 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1147 = torch.aten._unsafe_view %1143, %1146 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1147, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_1049 = torch.constant.int 0 - %int0_1050 = torch.constant.int 0 - %int9223372036854775807_1051 = torch.constant.int 9223372036854775807 - %int1_1052 = torch.constant.int 1 - %1148 = torch.aten.slice.Tensor %1147, %int0_1049, %int0_1050, %int9223372036854775807_1051, %int1_1052 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1148, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_1053 
= torch.constant.int -2 - %1149 = torch.aten.unsqueeze %1138, %int-2_1053 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1149, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_1054 = torch.constant.int 1 - %1150 = torch.aten.size.int %1137, %int1_1054 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_1055 = torch.constant.int 4 - %int8_1056 = torch.constant.int 8 - %int4_1057 = torch.constant.int 4 - %int128_1058 = torch.constant.int 128 - %1151 = torch.prim.ListConstruct %int4_1055, %1150, %int8_1056, %int4_1057, %int128_1058 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1059 = torch.constant.bool false - %1152 = torch.aten.expand %1149, %1151, %false_1059 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1152, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1060 = torch.constant.int 0 - %1153 = torch.aten.clone %1152, %int0_1060 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1153, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_1061 = torch.constant.int 4 - %int32_1062 = torch.constant.int 32 - %int128_1063 = torch.constant.int 128 - %1154 = torch.prim.ListConstruct %int4_1061, %1150, %int32_1062, %int128_1063 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1155 = torch.aten._unsafe_view %1153, %1154 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1155, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_1064 = torch.constant.int -2 - %1156 = torch.aten.unsqueeze %1148, %int-2_1064 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1156, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_1065 = torch.constant.int 1 - %1157 = torch.aten.size.int %1147, %int1_1065 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_1066 = torch.constant.int 4 - %int8_1067 = torch.constant.int 8 - %int4_1068 = torch.constant.int 4 - %int128_1069 = torch.constant.int 128 - %1158 = torch.prim.ListConstruct %int4_1066, %1157, %int8_1067, %int4_1068, %int128_1069 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1070 = torch.constant.bool false - %1159 = torch.aten.expand %1156, %1158, %false_1070 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1159, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1071 = torch.constant.int 0 - %1160 = torch.aten.clone %1159, %int0_1071 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1160, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4096_1047 = torch.constant.int 4096 + %1376 = torch.prim.ListConstruct %int4_1046, %int4096_1047 : (!torch.int, !torch.int) -> !torch.list + %1377 = torch.aten.view %1365, %1376 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1378 = torch.aten.mm %1377, %1375 : 
!torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_1048 = torch.constant.int 4 + %int1_1049 = torch.constant.int 1 + %int14336_1050 = torch.constant.int 14336 + %1379 = torch.prim.ListConstruct %int4_1048, %int1_1049, %int14336_1050 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1380 = torch.aten.view %1378, %1379 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %1381 = torch.aten.mul.Tensor %1373, %1380 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_1051 = torch.constant.int -2 + %int-1_1052 = torch.constant.int -1 + %1382 = torch.aten.transpose.int %57, %int-2_1051, %int-1_1052 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_1053 = torch.constant.int 5 + %1383 = torch.prims.convert_element_type %1382, %int5_1053 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_1054 = torch.constant.int 4 + %int14336_1055 = torch.constant.int 14336 + %1384 = torch.prim.ListConstruct %int4_1054, %int14336_1055 : (!torch.int, !torch.int) -> !torch.list + %1385 = torch.aten.view %1381, %1384 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %1386 = torch.aten.mm %1385, %1383 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_1056 = torch.constant.int 4 + %int1_1057 = torch.constant.int 1 + %int4096_1058 = torch.constant.int 4096 + %1387 = torch.prim.ListConstruct %int4_1056, %int1_1057, %int4096_1058 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1388 = torch.aten.view %1386, %1387 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_1059 = torch.constant.int 1 + %1389 = torch.aten.add.Tensor %1355, %1388, %int1_1059 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_1060 = torch.constant.int 6 + %1390 = torch.prims.convert_element_type %1389, %int6_1060 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_1061 = torch.constant.int 2 + %1391 = torch.aten.pow.Tensor_Scalar %1390, %int2_1061 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_1062 = torch.constant.int -1 + %1392 = torch.prim.ListConstruct %int-1_1062 : (!torch.int) -> !torch.list + %true_1063 = torch.constant.bool true + %none_1064 = torch.constant.none + %1393 = torch.aten.mean.dim %1391, %1392, %true_1063, %none_1064 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_1065 = torch.constant.float 9.9999997473787516E-6 + %int1_1066 = torch.constant.int 1 + %1394 = torch.aten.add.Scalar %1393, %float9.999990e-06_1065, %int1_1066 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %1395 = torch.aten.rsqrt %1394 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %1396 = torch.aten.mul.Tensor %1390, %1395 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_1067 = torch.constant.int 5 + %1397 = torch.prims.convert_element_type %1396, %int5_1067 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %1398 = torch.aten.mul.Tensor %58, %1397 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + 
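// Note (inferred): the lines above run RMSNorm (%54) into a SwiGLU feed-forward
// (silu(gate(x)) * up(x) into the down projection, with f16/f32 round-trips around the
// norm), add the residual, and begin the next normalization (%58). Below, the next
// block's projections start: %59 ([4096,4096]) yields Q viewed as [4,1,32,128], while
// %60 and %61 ([1024,4096]) yield K and V as [4,1,8,128], 8 KV heads against 32 query
// heads.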
%int5_1068 = torch.constant.int 5 + %1399 = torch.prims.convert_element_type %1398, %int5_1068 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_1069 = torch.constant.int -2 + %int-1_1070 = torch.constant.int -1 + %1400 = torch.aten.transpose.int %59, %int-2_1069, %int-1_1070 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_1071 = torch.constant.int 5 + %1401 = torch.prims.convert_element_type %1400, %int5_1071 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> %int4_1072 = torch.constant.int 4 - %int32_1073 = torch.constant.int 32 - %int128_1074 = torch.constant.int 128 - %1161 = torch.prim.ListConstruct %int4_1072, %1157, %int32_1073, %int128_1074 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1162 = torch.aten._unsafe_view %1160, %1161 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1162, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int4096_1073 = torch.constant.int 4096 + %1402 = torch.prim.ListConstruct %int4_1072, %int4096_1073 : (!torch.int, !torch.int) -> !torch.list + %1403 = torch.aten.view %1399, %1402 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1404 = torch.aten.mm %1403, %1401 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_1074 = torch.constant.int 4 %int1_1075 = torch.constant.int 1 - %int2_1076 = torch.constant.int 2 - %1163 = torch.aten.transpose.int %1043, %int1_1075, %int2_1076 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_1077 = torch.constant.int 1 - %int2_1078 = torch.constant.int 2 - %1164 = torch.aten.transpose.int %1155, %int1_1077, %int2_1078 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1164, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_1079 = torch.constant.int 1 - %int2_1080 = torch.constant.int 2 - %1165 = torch.aten.transpose.int %1162, %int1_1079, %int2_1080 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1165, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_1081 = torch.constant.float 0.000000e+00 - %false_1082 = torch.constant.bool false - %none_1083 = torch.constant.none - %1166:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1163, %1164, %1165, %float0.000000e00_1081, %false_1082, %368, %none_1083) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_1084 = torch.constant.int 1 - %int2_1085 = torch.constant.int 2 - %1167 = torch.aten.transpose.int %1166#0, %int1_1084, %int2_1085 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_1086 = torch.constant.int 4 - %int1_1087 = torch.constant.int 1 - %int4096_1088 = torch.constant.int 4096 - %1168 = torch.prim.ListConstruct %int4_1086, %int1_1087, %int4096_1088 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1169 = torch.aten.view %1167, %1168 : 
!torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_1089 = torch.constant.int -2 - %int-1_1090 = torch.constant.int -1 - %1170 = torch.aten.transpose.int %40, %int-2_1089, %int-1_1090 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_1091 = torch.constant.int 4 - %int4096_1092 = torch.constant.int 4096 - %1171 = torch.prim.ListConstruct %int4_1091, %int4096_1092 : (!torch.int, !torch.int) -> !torch.list - %1172 = torch.aten.view %1169, %1171 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1173 = torch.aten.mm %1172, %1170 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4096_1076 = torch.constant.int 4096 + %1405 = torch.prim.ListConstruct %int4_1074, %int1_1075, %int4096_1076 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1406 = torch.aten.view %1404, %1405 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_1077 = torch.constant.int -2 + %int-1_1078 = torch.constant.int -1 + %1407 = torch.aten.transpose.int %60, %int-2_1077, %int-1_1078 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_1079 = torch.constant.int 5 + %1408 = torch.prims.convert_element_type %1407, %int5_1079 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_1080 = torch.constant.int 4 + %int4096_1081 = torch.constant.int 4096 + %1409 = torch.prim.ListConstruct %int4_1080, %int4096_1081 : (!torch.int, !torch.int) -> !torch.list + %1410 = torch.aten.view %1399, %1409 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1411 = torch.aten.mm %1410, %1408 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_1082 = torch.constant.int 4 + %int1_1083 = torch.constant.int 1 + %int1024_1084 = torch.constant.int 1024 + %1412 = torch.prim.ListConstruct %int4_1082, %int1_1083, %int1024_1084 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1413 = torch.aten.view %1411, %1412 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_1085 = torch.constant.int -2 + %int-1_1086 = torch.constant.int -1 + %1414 = torch.aten.transpose.int %61, %int-2_1085, %int-1_1086 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_1087 = torch.constant.int 5 + %1415 = torch.prims.convert_element_type %1414, %int5_1087 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_1088 = torch.constant.int 4 + %int4096_1089 = torch.constant.int 4096 + %1416 = torch.prim.ListConstruct %int4_1088, %int4096_1089 : (!torch.int, !torch.int) -> !torch.list + %1417 = torch.aten.view %1399, %1416 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1418 = torch.aten.mm %1417, %1415 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_1090 = torch.constant.int 4 + %int1_1091 = torch.constant.int 1 + %int1024_1092 = torch.constant.int 1024 + %1419 = torch.prim.ListConstruct %int4_1090, %int1_1091, %int1024_1092 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1420 = torch.aten.view %1418, %1419 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> %int4_1093 = torch.constant.int 4 %int1_1094 = torch.constant.int 1 - %int4096_1095 = torch.constant.int 4096 - %1174 = 
- %1174 = torch.prim.ListConstruct %int4_1093, %int1_1094, %int4096_1095 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1175 = torch.aten.view %1173, %1174 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
- %int1_1096 = torch.constant.int 1
- %1176 = torch.aten.add.Tensor %1003, %1175, %int1_1096 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %int6_1097 = torch.constant.int 6
- %1177 = torch.prims.convert_element_type %1176, %int6_1097 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
- %int2_1098 = torch.constant.int 2
- %1178 = torch.aten.pow.Tensor_Scalar %1177, %int2_1098 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
- %int-1_1099 = torch.constant.int -1
- %1179 = torch.prim.ListConstruct %int-1_1099 : (!torch.int) -> !torch.list<int>
- %true_1100 = torch.constant.bool true
- %none_1101 = torch.constant.none
- %1180 = torch.aten.mean.dim %1178, %1179, %true_1100, %none_1101 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
- %float9.999990e-06_1102 = torch.constant.float 9.9999997473787516E-6
- %int1_1103 = torch.constant.int 1
- %1181 = torch.aten.add.Scalar %1180, %float9.999990e-06_1102, %int1_1103 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
- %1182 = torch.aten.rsqrt %1181 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
- %1183 = torch.aten.mul.Tensor %1177, %1182 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
- %int5_1104 = torch.constant.int 5
- %1184 = torch.prims.convert_element_type %1183, %int5_1104 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %1185 = torch.aten.mul.Tensor %41, %1184 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
- %int5_1105 = torch.constant.int 5
- %1186 = torch.prims.convert_element_type %1185, %int5_1105 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %int-2_1106 = torch.constant.int -2
- %int-1_1107 = torch.constant.int -1
- %1187 = torch.aten.transpose.int %42, %int-2_1106, %int-1_1107 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_1108 = torch.constant.int 4
- %int4096_1109 = torch.constant.int 4096
- %1188 = torch.prim.ListConstruct %int4_1108, %int4096_1109 : (!torch.int, !torch.int) -> !torch.list<int>
- %1189 = torch.aten.view %1186, %1188 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %1190 = torch.aten.mm %1189, %1187 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
- %int4_1110 = torch.constant.int 4
- %int1_1111 = torch.constant.int 1
- %int14336_1112 = torch.constant.int 14336
- %1191 = torch.prim.ListConstruct %int4_1110, %int1_1111, %int14336_1112 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1192 = torch.aten.view %1190, %1191 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
- %1193 = torch.aten.silu %1192 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
- %int-2_1113 = torch.constant.int -2
- %int-1_1114 = torch.constant.int -1
- %1194 = torch.aten.transpose.int %43, %int-2_1113, %int-1_1114 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_1115 = torch.constant.int 4
- %int4096_1116 = torch.constant.int 4096
- %1195 = torch.prim.ListConstruct %int4_1115, %int4096_1116 : (!torch.int, !torch.int) -> !torch.list<int>
- %1196 = torch.aten.view %1186, %1195 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %1197 = torch.aten.mm %1196, %1194 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
- %int4_1117 = torch.constant.int 4
- %int1_1118 = torch.constant.int 1
- %int14336_1119 = torch.constant.int 14336
- %1198 = torch.prim.ListConstruct %int4_1117, %int1_1118, %int14336_1119 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1199 = torch.aten.view %1197, %1198 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
- %1200 = torch.aten.mul.Tensor %1193, %1199 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
- %int-2_1120 = torch.constant.int -2
- %int-1_1121 = torch.constant.int -1
- %1201 = torch.aten.transpose.int %44, %int-2_1120, %int-1_1121 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
- %int4_1122 = torch.constant.int 4
- %int14336_1123 = torch.constant.int 14336
- %1202 = torch.prim.ListConstruct %int4_1122, %int14336_1123 : (!torch.int, !torch.int) -> !torch.list<int>
- %1203 = torch.aten.view %1200, %1202 : !torch.vtensor<[4,1,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,14336],f16>
- %1204 = torch.aten.mm %1203, %1201 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16>
- %int4_1124 = torch.constant.int 4
- %int1_1125 = torch.constant.int 1
- %int4096_1126 = torch.constant.int 4096
- %1205 = torch.prim.ListConstruct %int4_1124, %int1_1125, %int4096_1126 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1206 = torch.aten.view %1204, %1205 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
- %int1_1127 = torch.constant.int 1
- %1207 = torch.aten.add.Tensor %1176, %1206, %int1_1127 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %int6_1128 = torch.constant.int 6
- %1208 = torch.prims.convert_element_type %1207, %int6_1128 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
- %int2_1129 = torch.constant.int 2
- %1209 = torch.aten.pow.Tensor_Scalar %1208, %int2_1129 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
- %int-1_1130 = torch.constant.int -1
- %1210 = torch.prim.ListConstruct %int-1_1130 : (!torch.int) -> !torch.list<int>
- %true_1131 = torch.constant.bool true
- %none_1132 = torch.constant.none
- %1211 = torch.aten.mean.dim %1209, %1210, %true_1131, %none_1132 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
- %float9.999990e-06_1133 = torch.constant.float 9.9999997473787516E-6
+ %int32_1095 = torch.constant.int 32
+ %int128_1096 = torch.constant.int 128
+ %1421 = torch.prim.ListConstruct %int4_1093, %int1_1094, %int32_1095, %int128_1096 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1422 = torch.aten.view %1406, %1421 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,32,128],f16>
+ %int4_1097 = torch.constant.int 4
+ %int1_1098 = torch.constant.int 1
+ %int8_1099 = torch.constant.int 8
+ %int128_1100 = torch.constant.int 128
+ %1423 = torch.prim.ListConstruct %int4_1097, %int1_1098, %int8_1099, %int128_1100 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1424 = torch.aten.view %1413, %1423 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
+ %int4_1101 = torch.constant.int 4
+ %int1_1102 = torch.constant.int 1
+ %int8_1103 = torch.constant.int 8
+ %int128_1104 = torch.constant.int 128
+ %1425 = torch.prim.ListConstruct %int4_1101, %int1_1102, %int8_1103, %int128_1104 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1426 = torch.aten.view %1420, %1425 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
+ %int1_1105 = torch.constant.int 1
+ %int2_1106 = torch.constant.int 2
+ %1427 = torch.aten.transpose.int %1422, %int1_1105, %int2_1106 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+ %1428 = torch.aten.mul.Tensor %1427, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16>
+ %int3_1107 = torch.constant.int 3
+ %int0_1108 = torch.constant.int 0
+ %int64_1109 = torch.constant.int 64
+ %int1_1110 = torch.constant.int 1
+ %1429 = torch.aten.slice.Tensor %1427, %int3_1107, %int0_1108, %int64_1109, %int1_1110 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16>
+ %int3_1111 = torch.constant.int 3
+ %int64_1112 = torch.constant.int 64
+ %int9223372036854775807_1113 = torch.constant.int 9223372036854775807
+ %int1_1114 = torch.constant.int 1
+ %1430 = torch.aten.slice.Tensor %1427, %int3_1111, %int64_1112, %int9223372036854775807_1113, %int1_1114 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16>
+ %1431 = torch.aten.neg %1430 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16>
+ %1432 = torch.prim.ListConstruct %1431, %1429 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list<vtensor>
+ %int-1_1115 = torch.constant.int -1
+ %1433 = torch.aten.cat %1432, %int-1_1115 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+ %1434 = torch.aten.mul.Tensor %1433, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16>
+ %int1_1116 = torch.constant.int 1
+ %1435 = torch.aten.add.Tensor %1428, %1434, %int1_1116 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+ %int1_1117 = torch.constant.int 1
+ %int2_1118 = torch.constant.int 2
+ %1436 = torch.aten.transpose.int %1435, %int1_1117, %int2_1118 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
+ %int1_1119 = torch.constant.int 1
+ %int2_1120 = torch.constant.int 2
+ %1437 = torch.aten.transpose.int %1424, %int1_1119, %int2_1120 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16>
+ %1438 = torch.aten.mul.Tensor %1437, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16>
+ %int3_1121 = torch.constant.int 3
+ %int0_1122 = torch.constant.int 0
+ %int64_1123 = torch.constant.int 64
+ %int1_1124 = torch.constant.int 1
+ %1439 = torch.aten.slice.Tensor %1437, %int3_1121, %int0_1122, %int64_1123, %int1_1124 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16>
+ %int3_1125 = torch.constant.int 3
+ %int64_1126 = torch.constant.int 64
+ %int9223372036854775807_1127 = torch.constant.int 9223372036854775807
+ %int1_1128 = torch.constant.int 1
+ %1440 = torch.aten.slice.Tensor %1437, %int3_1125, %int64_1126, %int9223372036854775807_1127, %int1_1128 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16>
+ %1441 = torch.aten.neg %1440 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16>
+ %1442 = torch.prim.ListConstruct %1441, %1439 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list<vtensor>
+ %int-1_1129 = torch.constant.int -1
+ %1443 = torch.aten.cat %1442, %int-1_1129 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,8,1,128],f16>
+ %1444 = torch.aten.mul.Tensor %1443, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16>
+ %int1_1130 = torch.constant.int 1
+ %1445 = torch.aten.add.Tensor %1438, %1444, %int1_1130 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16>
+ %int1_1131 = torch.constant.int 1
+ %int2_1132 = torch.constant.int 2
+ %1446 = torch.aten.transpose.int %1445, %int1_1131, %int2_1132 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+ %int32_1133 = torch.constant.int 32
+ %1447 = torch.aten.floor_divide.Scalar %arg2, %int32_1133 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
 %int1_1134 = torch.constant.int 1
- %1212 = torch.aten.add.Scalar %1211, %float9.999990e-06_1133, %int1_1134 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
- %1213 = torch.aten.rsqrt %1212 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
- %1214 = torch.aten.mul.Tensor %1208, %1213 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
- %int5_1135 = torch.constant.int 5
- %1215 = torch.prims.convert_element_type %1214, %int5_1135 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %1216 = torch.aten.mul.Tensor %45, %1215 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
- %int5_1136 = torch.constant.int 5
- %1217 = torch.prims.convert_element_type %1216, %int5_1136 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %int-2_1137 = torch.constant.int -2
- %int-1_1138 = torch.constant.int -1
- %1218 = torch.aten.transpose.int %46, %int-2_1137, %int-1_1138 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_1139 = torch.constant.int 4
- %int4096_1140 = torch.constant.int 4096
- %1219 = torch.prim.ListConstruct %int4_1139, %int4096_1140 : (!torch.int, !torch.int) -> !torch.list<int>
- %1220 = torch.aten.view %1217, %1219 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %1221 = torch.aten.mm %1220, %1218 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
+ %1448 = torch.aten.unsqueeze %1447, %int1_1134 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
+ %int1_1135 = torch.constant.int 1
+ %false_1136 = torch.constant.bool false
+ %1449 = torch.aten.gather %arg3, %int1_1135, %1448, %false_1136 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64>
+ %int4_1137 = torch.constant.int 4
+ %int1_1138 = torch.constant.int 1
+ %int1_1139 = torch.constant.int 1
+ %1450 = torch.prim.ListConstruct %int4_1137, %int1_1138, %int1_1139 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1451 = torch.aten.view %1449, %1450 : !torch.vtensor<[4,1],si64>, !torch.list<int> -> !torch.vtensor<[4,1,1],si64>
+ %int32_1140 = torch.constant.int 32
+ %1452 = torch.aten.remainder.Scalar %arg2, %int32_1140 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
 %int4_1141 = torch.constant.int 4
 %int1_1142 = torch.constant.int 1
- %int4096_1143 = torch.constant.int 4096
- %1222 = torch.prim.ListConstruct %int4_1141, %int1_1142, %int4096_1143 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1223 = torch.aten.view %1221, %1222 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
- %int-2_1144 = torch.constant.int -2
- %int-1_1145 = torch.constant.int -1
- %1224 = torch.aten.transpose.int %47, %int-2_1144, %int-1_1145 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_1146 = torch.constant.int 4
- %int4096_1147 = torch.constant.int 4096
- %1225 = torch.prim.ListConstruct %int4_1146, %int4096_1147 : (!torch.int, !torch.int) -> !torch.list<int>
- %1226 = torch.aten.view %1217, %1225 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %1227 = torch.aten.mm %1226, %1224 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
- %int4_1148 = torch.constant.int 4
+ %int1_1143 = torch.constant.int 1
+ %1453 = torch.prim.ListConstruct %int4_1141, %int1_1142, %int1_1143 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1454 = torch.aten.view %1452, %1453 : !torch.vtensor<[4],si64>, !torch.list<int> -> !torch.vtensor<[4,1,1],si64>
+ %int8_1144 = torch.constant.int 8
+ %none_1145 = torch.constant.none
+ %none_1146 = torch.constant.none
+ %cpu_1147 = torch.constant.device "cpu"
+ %false_1148 = torch.constant.bool false
+ %1455 = torch.aten.arange %int8_1144, %none_1145, %none_1146, %cpu_1147, %false_1148 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64>
 %int1_1149 = torch.constant.int 1
- %int1024_1150 = torch.constant.int 1024
- %1228 = torch.prim.ListConstruct %int4_1148, %int1_1149, %int1024_1150 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1229 = torch.aten.view %1227, %1228 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
- %int-2_1151 = torch.constant.int -2
- %int-1_1152 = torch.constant.int -1
- %1230 = torch.aten.transpose.int %48, %int-2_1151, %int-1_1152 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_1153 = torch.constant.int 4
- %int4096_1154 = torch.constant.int 4096
- %1231 = torch.prim.ListConstruct %int4_1153, %int4096_1154 : (!torch.int, !torch.int) -> !torch.list<int>
- %1232 = torch.aten.view %1217, %1231 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %1233 = torch.aten.mm %1232, %1230 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
- %int4_1155 = torch.constant.int 4
- %int1_1156 = torch.constant.int 1
- %int1024_1157 = torch.constant.int 1024
- %1234 = torch.prim.ListConstruct %int4_1155, %int1_1156, %int1024_1157 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1235 = torch.aten.view %1233, %1234 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
- %int4_1158 = torch.constant.int 4
- %int1_1159 = torch.constant.int 1
- %int32_1160 = torch.constant.int 32
- %int128_1161 = torch.constant.int 128
- %1236 = torch.prim.ListConstruct %int4_1158, %int1_1159, %int32_1160, %int128_1161 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1237 = torch.aten.view %1223, %1236 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,32,128],f16>
- %int4_1162 = torch.constant.int 4
- %int1_1163 = torch.constant.int 1
- %int8_1164 = torch.constant.int 8
- %int128_1165 = torch.constant.int 128
- %1238 = torch.prim.ListConstruct %int4_1162, %int1_1163, %int8_1164, %int128_1165 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1239 = torch.aten.view %1229, %1238 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
- %int4_1166 = torch.constant.int 4
- %int1_1167 = torch.constant.int 1
+ %int1_1150 = torch.constant.int 1
+ %int8_1151 = torch.constant.int 8
+ %1456 = torch.prim.ListConstruct %int1_1149, %int1_1150, %int8_1151 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1457 = torch.aten.view %1455, %1456 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[1,1,8],si64>
+ %none_1152 = torch.constant.none
+ %1458 = torch.aten.clone %62, %none_1152 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %1459 = torch.aten.detach %1458 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %1460 = torch.aten.detach %1459 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %1461 = torch.aten.detach %1460 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %int1_1153 = torch.constant.int 1
+ %int1_1154 = torch.constant.int 1
+ %int1_1155 = torch.constant.int 1
+ %1462 = torch.prim.ListConstruct %int1_1153, %int1_1154, %int1_1155 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1463 = torch.aten.view %1461, %1462 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64>
+ %int32_1156 = torch.constant.int 32
+ %1464 = torch.aten.mul.Scalar %1451, %int32_1156 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int4_1157 = torch.constant.int 4
+ %int1_1158 = torch.constant.int 1
+ %1465 = torch.aten.add.Scalar %1464, %int4_1157, %int1_1158 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int2_1159 = torch.constant.int 2
+ %1466 = torch.aten.mul.Scalar %1465, %int2_1159 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int1_1160 = torch.constant.int 1
+ %1467 = torch.aten.add.Tensor %1466, %1463, %int1_1160 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int8_1161 = torch.constant.int 8
+ %1468 = torch.aten.mul.Scalar %1467, %int8_1161 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int1_1162 = torch.constant.int 1
+ %1469 = torch.aten.add.Tensor %1468, %1457, %int1_1162 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int32_1163 = torch.constant.int 32
+ %1470 = torch.aten.mul.Scalar %1469, %int32_1163 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int1_1164 = torch.constant.int 1
+ %1471 = torch.aten.add.Tensor %1470, %1454, %int1_1164 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int5_1165 = torch.constant.int 5
+ %1472 = torch.prims.convert_element_type %1446, %int5_1165 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+ %int32_1166 = torch.constant.int 32
+ %int2_1167 = torch.constant.int 2
 %int8_1168 = torch.constant.int 8
- %int128_1169 = torch.constant.int 128
- %1240 = torch.prim.ListConstruct %int4_1166, %int1_1167, %int8_1168, %int128_1169 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1241 = torch.aten.view %1235, %1240 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
- %int6_1170 = torch.constant.int 6
- %1242 = torch.prims.convert_element_type %1237, %int6_1170 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32>
- %1243 = torch_c.to_builtin_tensor %1242 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32>
- %1244 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32>
- %1245 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%1243, %1244) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32>
- %1246 = torch_c.from_builtin_tensor %1245 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32>
- %int5_1171 = torch.constant.int 5
- %1247 = torch.prims.convert_element_type %1246, %int5_1171 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
- %int6_1172 = torch.constant.int 6
- %1248 = torch.prims.convert_element_type %1239, %int6_1172 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32>
- %1249 = torch_c.to_builtin_tensor %1248 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32>
- %1250 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32>
- %1251 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%1249, %1250) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32>
- %1252 = torch_c.from_builtin_tensor %1251 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32>
- %int5_1173 = torch.constant.int 5
- %1253 = torch.prims.convert_element_type %1252, %int5_1173 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
- %int32_1174 = torch.constant.int 32
- %1254 = torch.aten.floor_divide.Scalar %arg2, %int32_1174 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
- %int1_1175 = torch.constant.int 1
- %1255 = torch.aten.unsqueeze %1254, %int1_1175 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int1_1176 = torch.constant.int 1
- %false_1177 = torch.constant.bool false
- %1256 = torch.aten.gather %arg3, %int1_1176, %1255, %false_1177 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64>
- %int32_1178 = torch.constant.int 32
- %1257 = torch.aten.remainder.Scalar %arg2, %int32_1178 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
- %int1_1179 = torch.constant.int 1
- %1258 = torch.aten.unsqueeze %1257, %int1_1179 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %none_1180 = torch.constant.none
- %1259 = torch.aten.clone %49, %none_1180 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
- %int0_1181 = torch.constant.int 0
- %1260 = torch.aten.unsqueeze %1259, %int0_1181 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
- %int4_1182 = torch.constant.int 4
- %int1_1183 = torch.constant.int 1
- %1261 = torch.prim.ListConstruct %int4_1182, %int1_1183 : (!torch.int, !torch.int) -> !torch.list<int>
- %int1_1184 = torch.constant.int 1
- %int1_1185 = torch.constant.int 1
- %1262 = torch.prim.ListConstruct %int1_1184, %int1_1185 : (!torch.int, !torch.int) -> !torch.list<int>
- %int4_1186 = torch.constant.int 4
- %int0_1187 = torch.constant.int 0
- %cpu_1188 = torch.constant.device "cpu"
- %false_1189 = torch.constant.bool false
- %1263 = torch.aten.empty_strided %1261, %1262, %int4_1186, %int0_1187, %cpu_1188, %false_1189 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64>
+ %int32_1169 = torch.constant.int 32
+ %int128_1170 = torch.constant.int 128
+ %1473 = torch.prim.ListConstruct %456, %int32_1166, %int2_1167, %int8_1168, %int32_1169, %int128_1170 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1474 = torch.aten.view %1294, %1473 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1474, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int128_1171 = torch.constant.int 128
+ %1475 = torch.prim.ListConstruct %596, %int128_1171 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1476 = torch.aten.view %1474, %1475 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %1476, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %1477 = torch.prim.ListConstruct %1471 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+ %false_1172 = torch.constant.bool false
+ %1478 = torch.aten.index_put %1476, %1477, %1472, %false_1172 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %1478, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %int32_1173 = torch.constant.int 32
+ %int2_1174 = torch.constant.int 2
+ %int8_1175 = torch.constant.int 8
+ %int32_1176 = torch.constant.int 32
+ %int128_1177 = torch.constant.int 128
+ %1479 = torch.prim.ListConstruct %456, %int32_1173, %int2_1174, %int8_1175, %int32_1176, %int128_1177 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1480 = torch.aten.view %1478, %1479 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1480, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_1178 = torch.constant.int 2097152
+ %1481 = torch.prim.ListConstruct %456, %int2097152_1178 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1482 = torch.aten.view %1480, %1481 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %1482, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %int32_1179 = torch.constant.int 32
+ %int2_1180 = torch.constant.int 2
+ %int8_1181 = torch.constant.int 8
+ %int32_1182 = torch.constant.int 32
+ %int128_1183 = torch.constant.int 128
+ %1483 = torch.prim.ListConstruct %456, %int32_1179, %int2_1180, %int8_1181, %int32_1182, %int128_1183 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1484 = torch.aten.view %1482, %1483 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1484, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int128_1184 = torch.constant.int 128
+ %1485 = torch.prim.ListConstruct %596, %int128_1184 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1486 = torch.aten.view %1484, %1485 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %1486, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %none_1185 = torch.constant.none
+ %1487 = torch.aten.clone %63, %none_1185 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %1488 = torch.aten.detach %1487 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %1489 = torch.aten.detach %1488 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %1490 = torch.aten.detach %1489 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %int1_1186 = torch.constant.int 1
+ %int1_1187 = torch.constant.int 1
+ %int1_1188 = torch.constant.int 1
+ %1491 = torch.prim.ListConstruct %int1_1186, %int1_1187, %int1_1188 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1492 = torch.aten.view %1490, %1491 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64>
+ %int32_1189 = torch.constant.int 32
+ %1493 = torch.aten.mul.Scalar %1451, %int32_1189 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
 %int4_1190 = torch.constant.int 4
- %1264 = torch.aten.fill.Scalar %1263, %int4_1190 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int4_1191 = torch.constant.int 4
- %int1_1192 = torch.constant.int 1
- %1265 = torch.prim.ListConstruct %int4_1191, %int1_1192 : (!torch.int, !torch.int) -> !torch.list<int>
- %1266 = torch.aten.repeat %1260, %1265 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64>
- %int32_1193 = torch.constant.int 32
- %1267 = torch.aten.mul.Scalar %1256, %int32_1193 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int1_1194 = torch.constant.int 1
- %1268 = torch.aten.add.Tensor %1267, %1264, %int1_1194 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int2_1195 = torch.constant.int 2
- %1269 = torch.aten.mul.Scalar %1268, %int2_1195 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int1_1196 = torch.constant.int 1
- %1270 = torch.aten.add.Tensor %1269, %1266, %int1_1196 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int32_1197 = torch.constant.int 32
- %1271 = torch.aten.mul.Scalar %1270, %int32_1197 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int1_1198 = torch.constant.int 1
- %1272 = torch.aten.add.Tensor %1271, %1258, %int1_1198 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int32_1199 = torch.constant.int 32
- %int2_1200 = torch.constant.int 2
- %int32_1201 = torch.constant.int 32
+ %int1_1191 = torch.constant.int 1
+ %1494 = torch.aten.add.Scalar %1493, %int4_1190, %int1_1191 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int2_1192 = torch.constant.int 2
+ %1495 = torch.aten.mul.Scalar %1494, %int2_1192 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int1_1193 = torch.constant.int 1
+ %1496 = torch.aten.add.Tensor %1495, %1492, %int1_1193 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int8_1194 = torch.constant.int 8
+ %1497 = torch.aten.mul.Scalar %1496, %int8_1194 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int1_1195 = torch.constant.int 1
+ %1498 = torch.aten.add.Tensor %1497, %1457, %int1_1195 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int32_1196 = torch.constant.int 32
+ %1499 = torch.aten.mul.Scalar %1498, %int32_1196 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int1_1197 = torch.constant.int 1
+ %1500 = torch.aten.add.Tensor %1499, %1454, %int1_1197 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int5_1198 = torch.constant.int 5
+ %1501 = torch.prims.convert_element_type %1426, %int5_1198 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+ %1502 = torch.prim.ListConstruct %1500 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+ %false_1199 = torch.constant.bool false
+ %1503 = torch.aten.index_put %1486, %1502, %1501, %false_1199 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %1503, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %int32_1200 = torch.constant.int 32
+ %int2_1201 = torch.constant.int 2
 %int8_1202 = torch.constant.int 8
+ %int32_1203 = torch.constant.int 32
+ %int128_1204 = torch.constant.int 128
+ %1504 = torch.prim.ListConstruct %456, %int32_1200, %int2_1201, %int8_1202, %int32_1203, %int128_1204 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1505 = torch.aten.view %1503, %1504 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1505, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_1205 = torch.constant.int 2097152
+ %1506 = torch.prim.ListConstruct %456, %int2097152_1205 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1507 = torch.aten.view %1505, %1506 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %1507, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %none_1206 = torch.constant.none
+ %1508 = torch.aten.clone %64, %none_1206 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %1509 = torch.aten.detach %1508 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %1510 = torch.aten.detach %1509 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %1511 = torch.aten.detach %1510 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %none_1207 = torch.constant.none
+ %1512 = torch.aten.clone %65, %none_1207 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %1513 = torch.aten.detach %1512 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %1514 = torch.aten.detach %1513 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %1515 = torch.aten.detach %1514 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %none_1208 = torch.constant.none
+ %1516 = torch.aten.clone %66, %none_1208 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %1517 = torch.aten.detach %1516 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %1518 = torch.aten.detach %1517 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %1519 = torch.aten.detach %1518 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %int32_1209 = torch.constant.int 32
+ %int2_1210 = torch.constant.int 2
+ %int8_1211 = torch.constant.int 8
 %int32_1212 = torch.constant.int 32
+ %int128_1213 = torch.constant.int 128
+ %1520 = torch.prim.ListConstruct %456, %int32_1209, %int2_1210, %int8_1211, %int32_1212, %int128_1213 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1521 = torch.aten.view %1507, %1520 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1521, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %1522 = torch_c.to_builtin_tensor %1521 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+ %1523 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+ %cast_1214 = tensor.cast %1523 : tensor<4x?xi64> to tensor<?x?xi64>
+ %1524 = torch_c.to_builtin_tensor %1511 : !torch.vtensor<[],si64> -> tensor<i64>
+ %1525 = torch_c.to_builtin_tensor %1515 : !torch.vtensor<[],si64> -> tensor<i64>
+ %1526 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%1522, %cast_1214, %1524, %1525) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+ %cast_1215 = tensor.cast %1526 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+ %1527 = torch_c.from_builtin_tensor %cast_1215 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+ torch.bind_symbolic_shape %1527, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+ %1528 = torch_c.to_builtin_tensor %1521 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+ %1529 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+ %cast_1216 = tensor.cast %1529 : tensor<4x?xi64> to tensor<?x?xi64>
+ %1530 = torch_c.to_builtin_tensor %1511 : !torch.vtensor<[],si64> -> tensor<i64>
+ %1531 = torch_c.to_builtin_tensor %1519 : !torch.vtensor<[],si64> -> tensor<i64>
+ %1532 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%1528, %cast_1216, %1530, %1531) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+ %cast_1217 = tensor.cast %1532 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+ %1533 = torch_c.from_builtin_tensor %cast_1217 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
%1290 = torch.aten.floor_divide.Scalar %arg2, %int32_1223 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_1224 = torch.constant.int 1 - %1291 = torch.aten.unsqueeze %1290, %int1_1224 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1225 = torch.constant.int 1 - %false_1226 = torch.constant.bool false - %1292 = torch.aten.gather %arg3, %int1_1225, %1291, %false_1226 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_1227 = torch.constant.int 32 - %1293 = torch.aten.remainder.Scalar %arg2, %int32_1227 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_1228 = torch.constant.int 1 - %1294 = torch.aten.unsqueeze %1293, %int1_1228 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_1229 = torch.constant.none - %1295 = torch.aten.clone %50, %none_1229 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_1230 = torch.constant.int 0 - %1296 = torch.aten.unsqueeze %1295, %int0_1230 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + %int128_1213 = torch.constant.int 128 + %1520 = torch.prim.ListConstruct %456, %int32_1209, %int2_1210, %int8_1211, %int32_1212, %int128_1213 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1521 = torch.aten.view %1507, %1520 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1521, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %1522 = torch_c.to_builtin_tensor %1521 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %1523 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_1214 = tensor.cast %1523 : tensor<4x?xi64> to tensor + %1524 = torch_c.to_builtin_tensor %1511 : !torch.vtensor<[],si64> -> tensor + %1525 = torch_c.to_builtin_tensor %1515 : !torch.vtensor<[],si64> -> tensor + %1526 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%1522, %cast_1214, %1524, %1525) : (tensor, tensor, tensor, tensor) -> tensor + %cast_1215 = tensor.cast %1526 : tensor to tensor<4x?x8x32x128xf16> + %1527 = torch_c.from_builtin_tensor %cast_1215 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %1527, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %1528 = torch_c.to_builtin_tensor %1521 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %1529 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_1216 = tensor.cast %1529 : tensor<4x?xi64> to tensor + %1530 = torch_c.to_builtin_tensor %1511 : !torch.vtensor<[],si64> -> tensor + %1531 = torch_c.to_builtin_tensor %1519 : !torch.vtensor<[],si64> -> tensor + %1532 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%1528, %cast_1216, %1530, %1531) : (tensor, tensor, tensor, tensor) -> tensor + %cast_1217 = tensor.cast %1532 : tensor to tensor<4x?x8x32x128xf16> + %1533 = torch_c.from_builtin_tensor %cast_1217 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + 
+ torch.bind_symbolic_shape %1533, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+ %int2_1218 = torch.constant.int 2
+ %int3_1219 = torch.constant.int 3
+ %1534 = torch.aten.transpose.int %1527, %int2_1218, %int3_1219 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %1534, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int0_1220 = torch.constant.int 0
+ %1535 = torch.aten.clone %1534, %int0_1220 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %1535, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int4_1221 = torch.constant.int 4
+ %int8_1222 = torch.constant.int 8
+ %int128_1223 = torch.constant.int 128
+ %1536 = torch.prim.ListConstruct %int4_1221, %457, %int8_1222, %int128_1223 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1537 = torch.aten._unsafe_view %1535, %1536 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %1537, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int2_1224 = torch.constant.int 2
+ %int3_1225 = torch.constant.int 3
+ %1538 = torch.aten.transpose.int %1533, %int2_1224, %int3_1225 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %1538, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int0_1226 = torch.constant.int 0
+ %1539 = torch.aten.clone %1538, %int0_1226 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %1539, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int4_1227 = torch.constant.int 4
+ %int8_1228 = torch.constant.int 8
+ %int128_1229 = torch.constant.int 128
+ %1540 = torch.prim.ListConstruct %int4_1227, %457, %int8_1228, %int128_1229 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1541 = torch.aten._unsafe_view %1539, %1540 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %1541, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int-2_1230 = torch.constant.int -2
+ %1542 = torch.aten.unsqueeze %1537, %int-2_1230 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+ torch.bind_symbolic_shape %1542, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
 %int4_1231 = torch.constant.int 4
- %int1_1232 = torch.constant.int 1
- %1297 = torch.prim.ListConstruct %int4_1231, %int1_1232 : (!torch.int, !torch.int) -> !torch.list<int>
- %int1_1233 = torch.constant.int 1
- %int1_1234 = torch.constant.int 1
- %1298 = torch.prim.ListConstruct %int1_1233, %int1_1234 : (!torch.int, !torch.int) -> !torch.list<int>
- %int4_1235 = torch.constant.int 4
+ %int8_1232 = torch.constant.int 8
+ %int4_1233 = torch.constant.int 4
+ %int128_1234 = torch.constant.int 128
+ %1543 = torch.prim.ListConstruct %int4_1231, %457, %int8_1232, %int4_1233, %int128_1234 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %false_1235 = torch.constant.bool false
+ %1544 = torch.aten.expand %1542, %1543, %false_1235 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %1544, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
 %int0_1236 = torch.constant.int 0
- %cpu_1237 = torch.constant.device "cpu"
- %false_1238 = torch.constant.bool false
- %1299 = torch.aten.empty_strided %1297, %1298, %int4_1235, %int0_1236, %cpu_1237, %false_1238 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64>
- %int4_1239 = torch.constant.int 4
- %1300 = torch.aten.fill.Scalar %1299, %int4_1239 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int4_1240 = torch.constant.int 4
- %int1_1241 = torch.constant.int 1
- %1301 = torch.prim.ListConstruct %int4_1240, %int1_1241 : (!torch.int, !torch.int) -> !torch.list<int>
- %1302 = torch.aten.repeat %1296, %1301 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64>
- %int32_1242 = torch.constant.int 32
- %1303 = torch.aten.mul.Scalar %1292, %int32_1242 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int1_1243 = torch.constant.int 1
- %1304 = torch.aten.add.Tensor %1303, %1300, %int1_1243 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int2_1244 = torch.constant.int 2
- %1305 = torch.aten.mul.Scalar %1304, %int2_1244 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int1_1245 = torch.constant.int 1
- %1306 = torch.aten.add.Tensor %1305, %1302, %int1_1245 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int32_1246 = torch.constant.int 32
- %1307 = torch.aten.mul.Scalar %1306, %int32_1246 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int1_1247 = torch.constant.int 1
- %1308 = torch.aten.add.Tensor %1307, %1294, %int1_1247 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %1309 = torch.prim.ListConstruct %1308 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>>
- %false_1248 = torch.constant.bool false
- %1310 = torch.aten.index_put %1289, %1309, %1241, %false_1248 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16>
- torch.bind_symbolic_shape %1310, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
- %int32_1249 = torch.constant.int 32
- %int2_1250 = torch.constant.int 2
- %int32_1251 = torch.constant.int 32
- %int8_1252 = torch.constant.int 8
- %int128_1253 = torch.constant.int 128
- %1311 = torch.prim.ListConstruct %437, %int32_1249, %int2_1250, %int32_1251, %int8_1252, %int128_1253 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1312 = torch.aten.view %1310, %1311 : !torch.vtensor<[?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %1312, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int2097152_1254 = torch.constant.int 2097152
- %1313 = torch.prim.ListConstruct %437, %int2097152_1254 : (!torch.int, !torch.int) -> !torch.list<int>
- %1314 = torch.aten.view %1312, %1313 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
- torch.bind_symbolic_shape %1314, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
- %int4_1255 = torch.constant.int 4
- %1315 = torch.prim.ListConstruct %int4_1255, %358 : (!torch.int, !torch.int) -> !torch.list<int>
- %int1_1256 = torch.constant.int 1
- %1316 = torch.prim.ListConstruct %358, %int1_1256 : (!torch.int, !torch.int) -> !torch.list<int>
- %int4_1257 = torch.constant.int 4
- %int0_1258 = torch.constant.int 0
- %cpu_1259 = torch.constant.device "cpu"
- %false_1260 = torch.constant.bool false
- %1317 = torch.aten.empty_strided %1315, %1316, %int4_1257, %int0_1258, %cpu_1259, %false_1260 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %1317, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %1545 = torch.aten.clone %1544, %int0_1236 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %1545, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4_1237 = torch.constant.int 4
+ %int32_1238 = torch.constant.int 32
+ %int128_1239 = torch.constant.int 128
+ %1546 = torch.prim.ListConstruct %int4_1237, %457, %int32_1238, %int128_1239 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1547 = torch.aten._unsafe_view %1545, %1546 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %1547, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int-2_1240 = torch.constant.int -2
+ %1548 = torch.aten.unsqueeze %1541, %int-2_1240 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+ torch.bind_symbolic_shape %1548, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+ %int4_1241 = torch.constant.int 4
+ %int8_1242 = torch.constant.int 8
+ %int4_1243 = torch.constant.int 4
+ %int128_1244 = torch.constant.int 128
+ %1549 = torch.prim.ListConstruct %int4_1241, %457, %int8_1242, %int4_1243, %int128_1244 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %false_1245 = torch.constant.bool false
+ %1550 = torch.aten.expand %1548, %1549, %false_1245 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %1550, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int0_1246 = torch.constant.int 0
+ %1551 = torch.aten.clone %1550, %int0_1246 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %1551, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4_1247 = torch.constant.int 4
+ %int32_1248 = torch.constant.int 32
+ %int128_1249 = torch.constant.int 128
+ %1552 = torch.prim.ListConstruct %int4_1247, %457, %int32_1248, %int128_1249 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1553 = torch.aten._unsafe_view %1551, %1552 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %1553, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int1_1250 = torch.constant.int 1
+ %int2_1251 = torch.constant.int 2
+ %1554 = torch.aten.transpose.int %1436, %int1_1250, %int2_1251 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+ %int1_1252 = torch.constant.int 1
+ %int2_1253 = torch.constant.int 2
+ %1555 = torch.aten.transpose.int %1547, %int1_1252, %int2_1253 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %1555, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_1254 = torch.constant.int 1
+ %int2_1255 = torch.constant.int 2
+ %1556 = torch.aten.transpose.int %1553, %int1_1254, %int2_1255 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %1556, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %float0.000000e00_1256 = torch.constant.float 0.000000e+00
+ %false_1257 = torch.constant.bool false
+ %none_1258 = torch.constant.none
+ %1557:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1554, %1555, %1556, %float0.000000e00_1256, %false_1257, %470, %none_1258) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
+ %int1_1259 = torch.constant.int 1
+ %int2_1260 = torch.constant.int 2
+ %1558 = torch.aten.transpose.int %1557#0, %int1_1259, %int2_1260 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
 %int4_1261 = torch.constant.int 4
- %1318 = torch.aten.fill.Scalar %1317, %int4_1261 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %1318, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int32_1262 = torch.constant.int 32
- %1319 = torch.aten.mul.Scalar %arg3, %int32_1262 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %1319, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int1_1263 = torch.constant.int 1
- %1320 = torch.aten.add.Tensor %1319, %1318, %int1_1263 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %1320, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int4_1264 = torch.constant.int 4
- %1321 = torch.aten.mul.int %int4_1264, %358 : !torch.int, !torch.int -> !torch.int
- %1322 = torch.prim.ListConstruct %1321 : (!torch.int) -> !torch.list<int>
- %1323 = torch.aten.view %1320, %1322 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
- torch.bind_symbolic_shape %1323, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
- %int32_1265 = torch.constant.int 32
- %int2_1266 = torch.constant.int 2
- %int32_1267 = torch.constant.int 32
- %int8_1268 = torch.constant.int 8
- %int128_1269 = torch.constant.int 128
- %1324 = torch.prim.ListConstruct %437, %int32_1265, %int2_1266, %int32_1267, %int8_1268, %int128_1269 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1325 = torch.aten.view %1314, %1324 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %1325, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int32_1270 = torch.constant.int 32
- %1326 = torch.aten.mul.int %437, %int32_1270 : !torch.int, !torch.int -> !torch.int
- %int2_1271 = torch.constant.int 2
- %int32_1272 = torch.constant.int 32
- %int8_1273 = torch.constant.int 8
- %int128_1274 = torch.constant.int 128
- %1327 = torch.prim.ListConstruct %1326, %int2_1271, %int32_1272, %int8_1273, %int128_1274 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1328 = torch.aten.view %1325, %1327 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2,32,8,128],f16>
- torch.bind_symbolic_shape %1328, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16>
- %int0_1275 = torch.constant.int 0
- %1329 = torch.aten.index_select %1328, %int0_1275, %1323 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16>
- torch.bind_symbolic_shape %1329, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16>
- %int4_1276 = torch.constant.int 4
- %int2_1277 = torch.constant.int 2
- %int32_1278 = torch.constant.int 32
- %int8_1279 = torch.constant.int 8
- %int128_1280 = torch.constant.int 128
- %1330 = torch.prim.ListConstruct %int4_1276, %358, %int2_1277, %int32_1278, %int8_1279, %int128_1280 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %1331 = torch.aten.view %1329, %1330 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %1331, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int0_1281 = torch.constant.int 0
- %int0_1282 = torch.constant.int 0
- %int9223372036854775807_1283 = torch.constant.int 9223372036854775807
- %int1_1284 = torch.constant.int 1
- %1332 = torch.aten.slice.Tensor %1331, %int0_1281, %int0_1282, %int9223372036854775807_1283, %int1_1284 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %1332, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int1_1285 = torch.constant.int 1
- %int0_1286 = torch.constant.int 0
- %int9223372036854775807_1287 = torch.constant.int 9223372036854775807
- %int1_1288 = torch.constant.int 1
- %1333 = torch.aten.slice.Tensor %1332, %int1_1285, %int0_1286, %int9223372036854775807_1287, %int1_1288 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %1333, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int2_1289 = torch.constant.int 2
- %int0_1290 = torch.constant.int 0
- %1334 = torch.aten.select.int %1333, %int2_1289, %int0_1290 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %1334, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int32_1291 = torch.constant.int 32
- %1335 = torch.aten.mul.int %358, %int32_1291 :
%1567, %int1_1272 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_1273 = torch.constant.int 6 + %1569 = torch.prims.convert_element_type %1568, %int6_1273 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_1274 = torch.constant.int 2 + %1570 = torch.aten.pow.Tensor_Scalar %1569, %int2_1274 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_1275 = torch.constant.int -1 + %1571 = torch.prim.ListConstruct %int-1_1275 : (!torch.int) -> !torch.list + %true_1276 = torch.constant.bool true + %none_1277 = torch.constant.none + %1572 = torch.aten.mean.dim %1570, %1571, %true_1276, %none_1277 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_1278 = torch.constant.float 9.9999997473787516E-6 + %int1_1279 = torch.constant.int 1 + %1573 = torch.aten.add.Scalar %1572, %float9.999990e-06_1278, %int1_1279 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %1574 = torch.aten.rsqrt %1573 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %1575 = torch.aten.mul.Tensor %1569, %1574 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_1280 = torch.constant.int 5 + %1576 = torch.prims.convert_element_type %1575, %int5_1280 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %1577 = torch.aten.mul.Tensor %68, %1576 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_1281 = torch.constant.int 5 + %1578 = torch.prims.convert_element_type %1577, %int5_1281 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_1282 = torch.constant.int -2 + %int-1_1283 = torch.constant.int -1 + %1579 = torch.aten.transpose.int %69, %int-2_1282, %int-1_1283 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_1284 = torch.constant.int 5 + %1580 = torch.prims.convert_element_type %1579, %int5_1284 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_1285 = torch.constant.int 4 + %int4096_1286 = torch.constant.int 4096 + %1581 = torch.prim.ListConstruct %int4_1285, %int4096_1286 : (!torch.int, !torch.int) -> !torch.list + %1582 = torch.aten.view %1578, %1581 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1583 = torch.aten.mm %1582, %1580 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_1287 = torch.constant.int 4 %int1_1288 = torch.constant.int 1 - %1333 = torch.aten.slice.Tensor %1332, %int1_1285, %int0_1286, %int9223372036854775807_1287, %int1_1288 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1333, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_1289 = torch.constant.int 2 - %int0_1290 = torch.constant.int 0 - %1334 = torch.aten.select.int %1333, %int2_1289, %int0_1290 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1334, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_1291 = torch.constant.int 32 - %1335 = torch.aten.mul.int %358, %int32_1291 : 
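// `+` path: residual add, then RMSNorm evaluated in f32 — y = x * rsqrt(mean(x^2, dim=-1) + eps)
// with eps ~= 1e-5 — rescaled by the layer norm weight and cast back to f16 for the FFN matmuls.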
!torch.int, !torch.int -> !torch.int - %int2_1292 = torch.constant.int 2 - %int0_1293 = torch.constant.int 0 - %int1_1294 = torch.constant.int 1 - %1336 = torch.aten.slice.Tensor %1334, %int2_1292, %int0_1293, %1335, %int1_1294 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1336, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_1295 = torch.constant.int 0 - %1337 = torch.aten.clone %1336, %int0_1295 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1337, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int14336_1289 = torch.constant.int 14336 + %1584 = torch.prim.ListConstruct %int4_1287, %int1_1288, %int14336_1289 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1585 = torch.aten.view %1583, %1584 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %1586 = torch.aten.silu %1585 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_1290 = torch.constant.int -2 + %int-1_1291 = torch.constant.int -1 + %1587 = torch.aten.transpose.int %70, %int-2_1290, %int-1_1291 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_1292 = torch.constant.int 5 + %1588 = torch.prims.convert_element_type %1587, %int5_1292 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_1293 = torch.constant.int 4 + %int4096_1294 = torch.constant.int 4096 + %1589 = torch.prim.ListConstruct %int4_1293, %int4096_1294 : (!torch.int, !torch.int) -> !torch.list + %1590 = torch.aten.view %1578, %1589 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1591 = torch.aten.mm %1590, %1588 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_1295 = torch.constant.int 4 %int1_1296 = torch.constant.int 1 - %1338 = torch.aten.size.int %1333, %int1_1296 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_1297 = torch.constant.int 32 - %1339 = torch.aten.mul.int %1338, %int32_1297 : !torch.int, !torch.int -> !torch.int - %int4_1298 = torch.constant.int 4 - %int8_1299 = torch.constant.int 8 - %int128_1300 = torch.constant.int 128 - %1340 = torch.prim.ListConstruct %int4_1298, %1339, %int8_1299, %int128_1300 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1341 = torch.aten._unsafe_view %1337, %1340 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1341, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_1301 = torch.constant.int 0 - %int0_1302 = torch.constant.int 0 - %int9223372036854775807_1303 = torch.constant.int 9223372036854775807 + %int14336_1297 = torch.constant.int 14336 + %1592 = torch.prim.ListConstruct %int4_1295, %int1_1296, %int14336_1297 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1593 = torch.aten.view %1591, %1592 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %1594 = torch.aten.mul.Tensor %1586, %1593 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_1298 = torch.constant.int -2 + %int-1_1299 = torch.constant.int -1 + %1595 = torch.aten.transpose.int %71, %int-2_1298, %int-1_1299 : 
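// `+` path: SwiGLU feed-forward — silu(x @ W_gate^T) * (x @ W_up^T), 4096 -> 14336 — with the
// W_down projection back to 4096 and the second residual add just below.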
!torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_1300 = torch.constant.int 5 + %1596 = torch.prims.convert_element_type %1595, %int5_1300 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_1301 = torch.constant.int 4 + %int14336_1302 = torch.constant.int 14336 + %1597 = torch.prim.ListConstruct %int4_1301, %int14336_1302 : (!torch.int, !torch.int) -> !torch.list + %1598 = torch.aten.view %1594, %1597 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %1599 = torch.aten.mm %1598, %1596 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_1303 = torch.constant.int 4 %int1_1304 = torch.constant.int 1 - %1342 = torch.aten.slice.Tensor %1341, %int0_1301, %int0_1302, %int9223372036854775807_1303, %int1_1304 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1342, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_1305 = torch.constant.int 0 - %int0_1306 = torch.constant.int 0 - %int9223372036854775807_1307 = torch.constant.int 9223372036854775807 - %int1_1308 = torch.constant.int 1 - %1343 = torch.aten.slice.Tensor %1331, %int0_1305, %int0_1306, %int9223372036854775807_1307, %int1_1308 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1343, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_1309 = torch.constant.int 1 - %int0_1310 = torch.constant.int 0 - %int9223372036854775807_1311 = torch.constant.int 9223372036854775807 - %int1_1312 = torch.constant.int 1 - %1344 = torch.aten.slice.Tensor %1343, %int1_1309, %int0_1310, %int9223372036854775807_1311, %int1_1312 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1344, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_1313 = torch.constant.int 2 - %int1_1314 = torch.constant.int 1 - %1345 = torch.aten.select.int %1344, %int2_1313, %int1_1314 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1345, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_1315 = torch.constant.int 2 - %int0_1316 = torch.constant.int 0 - %int1_1317 = torch.constant.int 1 - %1346 = torch.aten.slice.Tensor %1345, %int2_1315, %int0_1316, %1335, %int1_1317 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1346, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_1318 = torch.constant.int 0 - %1347 = torch.aten.clone %1346, %int0_1318 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1347, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_1319 = torch.constant.int 1 - %1348 = torch.aten.size.int %1344, %int1_1319 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_1320 = torch.constant.int 32 - %1349 = torch.aten.mul.int %1348, %int32_1320 : !torch.int, !torch.int -> 
!torch.int + %int4096_1305 = torch.constant.int 4096 + %1600 = torch.prim.ListConstruct %int4_1303, %int1_1304, %int4096_1305 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1601 = torch.aten.view %1599, %1600 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_1306 = torch.constant.int 1 + %1602 = torch.aten.add.Tensor %1568, %1601, %int1_1306 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_1307 = torch.constant.int 6 + %1603 = torch.prims.convert_element_type %1602, %int6_1307 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_1308 = torch.constant.int 2 + %1604 = torch.aten.pow.Tensor_Scalar %1603, %int2_1308 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_1309 = torch.constant.int -1 + %1605 = torch.prim.ListConstruct %int-1_1309 : (!torch.int) -> !torch.list + %true_1310 = torch.constant.bool true + %none_1311 = torch.constant.none + %1606 = torch.aten.mean.dim %1604, %1605, %true_1310, %none_1311 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_1312 = torch.constant.float 9.9999997473787516E-6 + %int1_1313 = torch.constant.int 1 + %1607 = torch.aten.add.Scalar %1606, %float9.999990e-06_1312, %int1_1313 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %1608 = torch.aten.rsqrt %1607 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %1609 = torch.aten.mul.Tensor %1603, %1608 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_1314 = torch.constant.int 5 + %1610 = torch.prims.convert_element_type %1609, %int5_1314 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %1611 = torch.aten.mul.Tensor %72, %1610 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_1315 = torch.constant.int 5 + %1612 = torch.prims.convert_element_type %1611, %int5_1315 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_1316 = torch.constant.int -2 + %int-1_1317 = torch.constant.int -1 + %1613 = torch.aten.transpose.int %73, %int-2_1316, %int-1_1317 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_1318 = torch.constant.int 5 + %1614 = torch.prims.convert_element_type %1613, %int5_1318 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_1319 = torch.constant.int 4 + %int4096_1320 = torch.constant.int 4096 + %1615 = torch.prim.ListConstruct %int4_1319, %int4096_1320 : (!torch.int, !torch.int) -> !torch.list + %1616 = torch.aten.view %1612, %1615 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1617 = torch.aten.mm %1616, %1614 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_1321 = torch.constant.int 4 - %int8_1322 = torch.constant.int 8 - %int128_1323 = torch.constant.int 128 - %1350 = torch.prim.ListConstruct %int4_1321, %1349, %int8_1322, %int128_1323 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1351 = torch.aten._unsafe_view %1347, %1350 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1351, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - 
%int0_1324 = torch.constant.int 0 - %int0_1325 = torch.constant.int 0 - %int9223372036854775807_1326 = torch.constant.int 9223372036854775807 - %int1_1327 = torch.constant.int 1 - %1352 = torch.aten.slice.Tensor %1351, %int0_1324, %int0_1325, %int9223372036854775807_1326, %int1_1327 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1352, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_1328 = torch.constant.int -2 - %1353 = torch.aten.unsqueeze %1342, %int-2_1328 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1353, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_1329 = torch.constant.int 1 - %1354 = torch.aten.size.int %1341, %int1_1329 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_1330 = torch.constant.int 4 - %int8_1331 = torch.constant.int 8 - %int4_1332 = torch.constant.int 4 - %int128_1333 = torch.constant.int 128 - %1355 = torch.prim.ListConstruct %int4_1330, %1354, %int8_1331, %int4_1332, %int128_1333 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1334 = torch.constant.bool false - %1356 = torch.aten.expand %1353, %1355, %false_1334 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1356, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1335 = torch.constant.int 0 - %1357 = torch.aten.clone %1356, %int0_1335 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1357, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_1336 = torch.constant.int 4 - %int32_1337 = torch.constant.int 32 - %int128_1338 = torch.constant.int 128 - %1358 = torch.prim.ListConstruct %int4_1336, %1354, %int32_1337, %int128_1338 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1359 = torch.aten._unsafe_view %1357, %1358 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1359, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_1339 = torch.constant.int -2 - %1360 = torch.aten.unsqueeze %1352, %int-2_1339 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1360, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_1340 = torch.constant.int 1 - %1361 = torch.aten.size.int %1351, %int1_1340 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_1341 = torch.constant.int 4 - %int8_1342 = torch.constant.int 8 - %int4_1343 = torch.constant.int 4 - %int128_1344 = torch.constant.int 128 - %1362 = torch.prim.ListConstruct %int4_1341, %1361, %int8_1342, %int4_1343, %int128_1344 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1345 = torch.constant.bool false - %1363 = torch.aten.expand %1360, %1362, %false_1345 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1363, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1346 = torch.constant.int 0 - %1364 = torch.aten.clone %1363, 
%int0_1346 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1364, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_1347 = torch.constant.int 4 - %int32_1348 = torch.constant.int 32 - %int128_1349 = torch.constant.int 128 - %1365 = torch.prim.ListConstruct %int4_1347, %1361, %int32_1348, %int128_1349 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1366 = torch.aten._unsafe_view %1364, %1365 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1366, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_1350 = torch.constant.int 1 - %int2_1351 = torch.constant.int 2 - %1367 = torch.aten.transpose.int %1247, %int1_1350, %int2_1351 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_1322 = torch.constant.int 1 + %int4096_1323 = torch.constant.int 4096 + %1618 = torch.prim.ListConstruct %int4_1321, %int1_1322, %int4096_1323 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1619 = torch.aten.view %1617, %1618 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_1324 = torch.constant.int -2 + %int-1_1325 = torch.constant.int -1 + %1620 = torch.aten.transpose.int %74, %int-2_1324, %int-1_1325 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_1326 = torch.constant.int 5 + %1621 = torch.prims.convert_element_type %1620, %int5_1326 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_1327 = torch.constant.int 4 + %int4096_1328 = torch.constant.int 4096 + %1622 = torch.prim.ListConstruct %int4_1327, %int4096_1328 : (!torch.int, !torch.int) -> !torch.list + %1623 = torch.aten.view %1612, %1622 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1624 = torch.aten.mm %1623, %1621 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_1329 = torch.constant.int 4 + %int1_1330 = torch.constant.int 1 + %int1024_1331 = torch.constant.int 1024 + %1625 = torch.prim.ListConstruct %int4_1329, %int1_1330, %int1024_1331 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1626 = torch.aten.view %1624, %1625 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_1332 = torch.constant.int -2 + %int-1_1333 = torch.constant.int -1 + %1627 = torch.aten.transpose.int %75, %int-2_1332, %int-1_1333 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_1334 = torch.constant.int 5 + %1628 = torch.prims.convert_element_type %1627, %int5_1334 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_1335 = torch.constant.int 4 + %int4096_1336 = torch.constant.int 4096 + %1629 = torch.prim.ListConstruct %int4_1335, %int4096_1336 : (!torch.int, !torch.int) -> !torch.list + %1630 = torch.aten.view %1612, %1629 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1631 = torch.aten.mm %1630, %1628 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_1337 = torch.constant.int 4 + %int1_1338 = torch.constant.int 1 + %int1024_1339 = torch.constant.int 1024 + %1632 = torch.prim.ListConstruct %int4_1337, %int1_1338, %int1024_1339 : (!torch.int, !torch.int, 
!torch.int) -> !torch.list + %1633 = torch.aten.view %1631, %1632 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_1340 = torch.constant.int 4 + %int1_1341 = torch.constant.int 1 + %int32_1342 = torch.constant.int 32 + %int128_1343 = torch.constant.int 128 + %1634 = torch.prim.ListConstruct %int4_1340, %int1_1341, %int32_1342, %int128_1343 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1635 = torch.aten.view %1619, %1634 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_1344 = torch.constant.int 4 + %int1_1345 = torch.constant.int 1 + %int8_1346 = torch.constant.int 8 + %int128_1347 = torch.constant.int 128 + %1636 = torch.prim.ListConstruct %int4_1344, %int1_1345, %int8_1346, %int128_1347 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1637 = torch.aten.view %1626, %1636 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_1348 = torch.constant.int 4 + %int1_1349 = torch.constant.int 1 + %int8_1350 = torch.constant.int 8 + %int128_1351 = torch.constant.int 128 + %1638 = torch.prim.ListConstruct %int4_1348, %int1_1349, %int8_1350, %int128_1351 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1639 = torch.aten.view %1633, %1638 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> %int1_1352 = torch.constant.int 1 %int2_1353 = torch.constant.int 2 - %1368 = torch.aten.transpose.int %1359, %int1_1352, %int2_1353 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1368, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_1354 = torch.constant.int 1 - %int2_1355 = torch.constant.int 2 - %1369 = torch.aten.transpose.int %1366, %int1_1354, %int2_1355 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1369, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_1356 = torch.constant.float 0.000000e+00 - %false_1357 = torch.constant.bool false - %none_1358 = torch.constant.none - %1370:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1367, %1368, %1369, %float0.000000e00_1356, %false_1357, %368, %none_1358) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_1359 = torch.constant.int 1 - %int2_1360 = torch.constant.int 2 - %1371 = torch.aten.transpose.int %1370#0, %int1_1359, %int2_1360 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_1361 = torch.constant.int 4 - %int1_1362 = torch.constant.int 1 - %int4096_1363 = torch.constant.int 4096 - %1372 = torch.prim.ListConstruct %int4_1361, %int1_1362, %int4096_1363 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1373 = torch.aten.view %1371, %1372 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_1364 = torch.constant.int -2 - %int-1_1365 = torch.constant.int -1 - %1374 = torch.aten.transpose.int %51, %int-2_1364, %int-1_1365 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_1366 = torch.constant.int 4 - %int4096_1367 = torch.constant.int 
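// `+` path: per-head reshape of this block's projections — Q: [4,1,4096] -> [4,1,32,128],
// K and V: [4,1,1024] -> [4,1,8,128]. 8 KV heads against 32 Q heads is the grouped-query
// (GQA) configuration; each KV head is later broadcast to 4 query heads.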
4096 - %1375 = torch.prim.ListConstruct %int4_1366, %int4096_1367 : (!torch.int, !torch.int) -> !torch.list - %1376 = torch.aten.view %1373, %1375 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1377 = torch.aten.mm %1376, %1374 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_1368 = torch.constant.int 4 - %int1_1369 = torch.constant.int 1 - %int4096_1370 = torch.constant.int 4096 - %1378 = torch.prim.ListConstruct %int4_1368, %int1_1369, %int4096_1370 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1379 = torch.aten.view %1377, %1378 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %1640 = torch.aten.transpose.int %1635, %int1_1352, %int2_1353 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %1641 = torch.aten.mul.Tensor %1640, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_1354 = torch.constant.int 3 + %int0_1355 = torch.constant.int 0 + %int64_1356 = torch.constant.int 64 + %int1_1357 = torch.constant.int 1 + %1642 = torch.aten.slice.Tensor %1640, %int3_1354, %int0_1355, %int64_1356, %int1_1357 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_1358 = torch.constant.int 3 + %int64_1359 = torch.constant.int 64 + %int9223372036854775807_1360 = torch.constant.int 9223372036854775807 + %int1_1361 = torch.constant.int 1 + %1643 = torch.aten.slice.Tensor %1640, %int3_1358, %int64_1359, %int9223372036854775807_1360, %int1_1361 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %1644 = torch.aten.neg %1643 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %1645 = torch.prim.ListConstruct %1644, %1642 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_1362 = torch.constant.int -1 + %1646 = torch.aten.cat %1645, %int-1_1362 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %1647 = torch.aten.mul.Tensor %1646, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_1363 = torch.constant.int 1 + %1648 = torch.aten.add.Tensor %1641, %1647, %int1_1363 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_1364 = torch.constant.int 1 + %int2_1365 = torch.constant.int 2 + %1649 = torch.aten.transpose.int %1648, %int1_1364, %int2_1365 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_1366 = torch.constant.int 1 + %int2_1367 = torch.constant.int 2 + %1650 = torch.aten.transpose.int %1637, %int1_1366, %int2_1367 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %1651 = torch.aten.mul.Tensor %1650, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_1368 = torch.constant.int 3 + %int0_1369 = torch.constant.int 0 + %int64_1370 = torch.constant.int 64 %int1_1371 = torch.constant.int 1 - %1380 = torch.aten.add.Tensor %1207, %1379, %int1_1371 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_1372 = torch.constant.int 6 - %1381 = torch.prims.convert_element_type %1380, %int6_1372 : 
!torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_1373 = torch.constant.int 2 - %1382 = torch.aten.pow.Tensor_Scalar %1381, %int2_1373 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_1374 = torch.constant.int -1 - %1383 = torch.prim.ListConstruct %int-1_1374 : (!torch.int) -> !torch.list - %true_1375 = torch.constant.bool true - %none_1376 = torch.constant.none - %1384 = torch.aten.mean.dim %1382, %1383, %true_1375, %none_1376 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_1377 = torch.constant.float 9.9999997473787516E-6 + %1652 = torch.aten.slice.Tensor %1650, %int3_1368, %int0_1369, %int64_1370, %int1_1371 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_1372 = torch.constant.int 3 + %int64_1373 = torch.constant.int 64 + %int9223372036854775807_1374 = torch.constant.int 9223372036854775807 + %int1_1375 = torch.constant.int 1 + %1653 = torch.aten.slice.Tensor %1650, %int3_1372, %int64_1373, %int9223372036854775807_1374, %int1_1375 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %1654 = torch.aten.neg %1653 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %1655 = torch.prim.ListConstruct %1654, %1652 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_1376 = torch.constant.int -1 + %1656 = torch.aten.cat %1655, %int-1_1376 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %1657 = torch.aten.mul.Tensor %1656, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_1377 = torch.constant.int 1 + %1658 = torch.aten.add.Tensor %1651, %1657, %int1_1377 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> %int1_1378 = torch.constant.int 1 - %1385 = torch.aten.add.Scalar %1384, %float9.999990e-06_1377, %int1_1378 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %1386 = torch.aten.rsqrt %1385 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %1387 = torch.aten.mul.Tensor %1381, %1386 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_1379 = torch.constant.int 5 - %1388 = torch.prims.convert_element_type %1387, %int5_1379 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %1389 = torch.aten.mul.Tensor %52, %1388 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_1380 = torch.constant.int 5 - %1390 = torch.prims.convert_element_type %1389, %int5_1380 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_1381 = torch.constant.int -2 - %int-1_1382 = torch.constant.int -1 - %1391 = torch.aten.transpose.int %53, %int-2_1381, %int-1_1382 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_1383 = torch.constant.int 4 - %int4096_1384 = torch.constant.int 4096 - %1392 = torch.prim.ListConstruct %int4_1383, %int4096_1384 : (!torch.int, !torch.int) -> !torch.list - %1393 = torch.aten.view %1390, %1392 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1394 = torch.aten.mm %1393, %1391 : !torch.vtensor<[4,4096],f16>, 
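// `+` path: rotary embedding inlined as rotate-half — out = x*cos + cat(-x[..,64:], x[..,:64])*sin —
// applied to Q (%1648) and K (%1658). The removed `-` path instead dispatched to the
// sharktank_rotary_embedding_* microkernels (deleted further down in this hunk).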
!torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_1385 = torch.constant.int 4 + %int2_1379 = torch.constant.int 2 + %1659 = torch.aten.transpose.int %1658, %int1_1378, %int2_1379 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_1380 = torch.constant.int 32 + %1660 = torch.aten.floor_divide.Scalar %arg2, %int32_1380 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_1381 = torch.constant.int 1 + %1661 = torch.aten.unsqueeze %1660, %int1_1381 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_1382 = torch.constant.int 1 + %false_1383 = torch.constant.bool false + %1662 = torch.aten.gather %arg3, %int1_1382, %1661, %false_1383 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_1384 = torch.constant.int 4 + %int1_1385 = torch.constant.int 1 %int1_1386 = torch.constant.int 1 - %int14336_1387 = torch.constant.int 14336 - %1395 = torch.prim.ListConstruct %int4_1385, %int1_1386, %int14336_1387 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1396 = torch.aten.view %1394, %1395 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %1397 = torch.aten.silu %1396 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_1388 = torch.constant.int -2 - %int-1_1389 = torch.constant.int -1 - %1398 = torch.aten.transpose.int %54, %int-2_1388, %int-1_1389 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_1390 = torch.constant.int 4 - %int4096_1391 = torch.constant.int 4096 - %1399 = torch.prim.ListConstruct %int4_1390, %int4096_1391 : (!torch.int, !torch.int) -> !torch.list - %1400 = torch.aten.view %1390, %1399 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1401 = torch.aten.mm %1400, %1398 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_1392 = torch.constant.int 4 - %int1_1393 = torch.constant.int 1 - %int14336_1394 = torch.constant.int 14336 - %1402 = torch.prim.ListConstruct %int4_1392, %int1_1393, %int14336_1394 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1403 = torch.aten.view %1401, %1402 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %1404 = torch.aten.mul.Tensor %1397, %1403 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_1395 = torch.constant.int -2 - %int-1_1396 = torch.constant.int -1 - %1405 = torch.aten.transpose.int %55, %int-2_1395, %int-1_1396 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_1397 = torch.constant.int 4 - %int14336_1398 = torch.constant.int 14336 - %1406 = torch.prim.ListConstruct %int4_1397, %int14336_1398 : (!torch.int, !torch.int) -> !torch.list - %1407 = torch.aten.view %1404, %1406 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %1408 = torch.aten.mm %1407, %1405 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_1399 = torch.constant.int 4 + %1663 = torch.prim.ListConstruct %int4_1384, %int1_1385, %int1_1386 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1664 = torch.aten.view %1662, %1663 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_1387 = torch.constant.int 32 + %1665 = 
torch.aten.remainder.Scalar %arg2, %int32_1387 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_1388 = torch.constant.int 4 + %int1_1389 = torch.constant.int 1 + %int1_1390 = torch.constant.int 1 + %1666 = torch.prim.ListConstruct %int4_1388, %int1_1389, %int1_1390 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1667 = torch.aten.view %1665, %1666 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_1391 = torch.constant.int 8 + %none_1392 = torch.constant.none + %none_1393 = torch.constant.none + %cpu_1394 = torch.constant.device "cpu" + %false_1395 = torch.constant.bool false + %1668 = torch.aten.arange %int8_1391, %none_1392, %none_1393, %cpu_1394, %false_1395 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_1396 = torch.constant.int 1 + %int1_1397 = torch.constant.int 1 + %int8_1398 = torch.constant.int 8 + %1669 = torch.prim.ListConstruct %int1_1396, %int1_1397, %int8_1398 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1670 = torch.aten.view %1668, %1669 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_1399 = torch.constant.none + %1671 = torch.aten.clone %76, %none_1399 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1672 = torch.aten.detach %1671 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1673 = torch.aten.detach %1672 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1674 = torch.aten.detach %1673 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %int1_1400 = torch.constant.int 1 - %int4096_1401 = torch.constant.int 4096 - %1409 = torch.prim.ListConstruct %int4_1399, %int1_1400, %int4096_1401 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1410 = torch.aten.view %1408, %1409 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_1401 = torch.constant.int 1 %int1_1402 = torch.constant.int 1 - %1411 = torch.aten.add.Tensor %1380, %1410, %int1_1402 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_1403 = torch.constant.int 6 - %1412 = torch.prims.convert_element_type %1411, %int6_1403 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_1404 = torch.constant.int 2 - %1413 = torch.aten.pow.Tensor_Scalar %1412, %int2_1404 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_1405 = torch.constant.int -1 - %1414 = torch.prim.ListConstruct %int-1_1405 : (!torch.int) -> !torch.list - %true_1406 = torch.constant.bool true - %none_1407 = torch.constant.none - %1415 = torch.aten.mean.dim %1413, %1414, %true_1406, %none_1407 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_1408 = torch.constant.float 9.9999997473787516E-6 + %1675 = torch.prim.ListConstruct %int1_1400, %int1_1401, %int1_1402 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1676 = torch.aten.view %1674, %1675 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_1403 = torch.constant.int 32 + %1677 = torch.aten.mul.Scalar %1664, %int32_1403 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int5_1404 = torch.constant.int 5 + %int1_1405 = torch.constant.int 1 + %1678 = torch.aten.add.Scalar %1677, %int5_1404, %int1_1405 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_1406 = 
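// `+` path: cache addressing for the decode position — page = %arg2 // 32, looked up in the
// %arg3 page table, and in-page offset = %arg2 % 32. The arithmetic completing below forms the
// flat row index (((page*32 + 5)*2 + part)*8 + head)*32 + offset into a [rows, 128] view of the
// cache; the constant 5 is presumably this block's layer slot, and `part` selects K vs. V.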
torch.constant.int 2 + %1679 = torch.aten.mul.Scalar %1678, %int2_1406 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_1407 = torch.constant.int 1 + %1680 = torch.aten.add.Tensor %1679, %1676, %int1_1407 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_1408 = torch.constant.int 8 + %1681 = torch.aten.mul.Scalar %1680, %int8_1408 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_1409 = torch.constant.int 1 - %1416 = torch.aten.add.Scalar %1415, %float9.999990e-06_1408, %int1_1409 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %1417 = torch.aten.rsqrt %1416 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %1418 = torch.aten.mul.Tensor %1412, %1417 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_1410 = torch.constant.int 5 - %1419 = torch.prims.convert_element_type %1418, %int5_1410 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %1420 = torch.aten.mul.Tensor %56, %1419 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_1411 = torch.constant.int 5 - %1421 = torch.prims.convert_element_type %1420, %int5_1411 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_1412 = torch.constant.int -2 - %int-1_1413 = torch.constant.int -1 - %1422 = torch.aten.transpose.int %57, %int-2_1412, %int-1_1413 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_1414 = torch.constant.int 4 - %int4096_1415 = torch.constant.int 4096 - %1423 = torch.prim.ListConstruct %int4_1414, %int4096_1415 : (!torch.int, !torch.int) -> !torch.list - %1424 = torch.aten.view %1421, %1423 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1425 = torch.aten.mm %1424, %1422 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_1416 = torch.constant.int 4 - %int1_1417 = torch.constant.int 1 - %int4096_1418 = torch.constant.int 4096 - %1426 = torch.prim.ListConstruct %int4_1416, %int1_1417, %int4096_1418 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1427 = torch.aten.view %1425, %1426 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_1419 = torch.constant.int -2 - %int-1_1420 = torch.constant.int -1 - %1428 = torch.aten.transpose.int %58, %int-2_1419, %int-1_1420 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_1421 = torch.constant.int 4 - %int4096_1422 = torch.constant.int 4096 - %1429 = torch.prim.ListConstruct %int4_1421, %int4096_1422 : (!torch.int, !torch.int) -> !torch.list - %1430 = torch.aten.view %1421, %1429 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1431 = torch.aten.mm %1430, %1428 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_1423 = torch.constant.int 4 - %int1_1424 = torch.constant.int 1 - %int1024_1425 = torch.constant.int 1024 - %1432 = torch.prim.ListConstruct %int4_1423, %int1_1424, %int1024_1425 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1433 = torch.aten.view %1431, %1432 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_1426 = torch.constant.int -2 - %int-1_1427 = torch.constant.int -1 - 
%1434 = torch.aten.transpose.int %59, %int-2_1426, %int-1_1427 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_1428 = torch.constant.int 4 - %int4096_1429 = torch.constant.int 4096 - %1435 = torch.prim.ListConstruct %int4_1428, %int4096_1429 : (!torch.int, !torch.int) -> !torch.list - %1436 = torch.aten.view %1421, %1435 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1437 = torch.aten.mm %1436, %1434 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_1430 = torch.constant.int 4 - %int1_1431 = torch.constant.int 1 - %int1024_1432 = torch.constant.int 1024 - %1438 = torch.prim.ListConstruct %int4_1430, %int1_1431, %int1024_1432 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1439 = torch.aten.view %1437, %1438 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_1433 = torch.constant.int 4 + %1682 = torch.aten.add.Tensor %1681, %1670, %int1_1409 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_1410 = torch.constant.int 32 + %1683 = torch.aten.mul.Scalar %1682, %int32_1410 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_1411 = torch.constant.int 1 + %1684 = torch.aten.add.Tensor %1683, %1667, %int1_1411 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_1412 = torch.constant.int 5 + %1685 = torch.prims.convert_element_type %1659, %int5_1412 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_1413 = torch.constant.int 32 + %int2_1414 = torch.constant.int 2 + %int8_1415 = torch.constant.int 8 + %int32_1416 = torch.constant.int 32 + %int128_1417 = torch.constant.int 128 + %1686 = torch.prim.ListConstruct %456, %int32_1413, %int2_1414, %int8_1415, %int32_1416, %int128_1417 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1687 = torch.aten.view %1507, %1686 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1687, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_1418 = torch.constant.int 128 + %1688 = torch.prim.ListConstruct %596, %int128_1418 : (!torch.int, !torch.int) -> !torch.list + %1689 = torch.aten.view %1687, %1688 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1689, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %1690 = torch.prim.ListConstruct %1684 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_1419 = torch.constant.bool false + %1691 = torch.aten.index_put %1689, %1690, %1685, %false_1419 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1691, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_1420 = torch.constant.int 32 + %int2_1421 = torch.constant.int 2 + %int8_1422 = torch.constant.int 8 + %int32_1423 = torch.constant.int 32 + %int128_1424 = torch.constant.int 128 + %1692 = torch.prim.ListConstruct %456, %int32_1420, %int2_1421, %int8_1422, %int32_1423, %int128_1424 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1693 = torch.aten.view %1691, %1692 : 
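// `+` path: the index_put above scatters the rotated K row into the flat [rows, 128] cache view;
// the view ops below fold it back to the opaque [?, 2097152] page buffer. Note the new layout is
// [pages, 32, 2, 8, 32, 128] (KV head ahead of in-page offset), swapped from the old layout.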
!torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1693, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_1425 = torch.constant.int 2097152 + %1694 = torch.prim.ListConstruct %456, %int2097152_1425 : (!torch.int, !torch.int) -> !torch.list + %1695 = torch.aten.view %1693, %1694 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %1695, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_1426 = torch.constant.int 32 + %int2_1427 = torch.constant.int 2 + %int8_1428 = torch.constant.int 8 + %int32_1429 = torch.constant.int 32 + %int128_1430 = torch.constant.int 128 + %1696 = torch.prim.ListConstruct %456, %int32_1426, %int2_1427, %int8_1428, %int32_1429, %int128_1430 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1697 = torch.aten.view %1695, %1696 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1697, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_1431 = torch.constant.int 128 + %1698 = torch.prim.ListConstruct %596, %int128_1431 : (!torch.int, !torch.int) -> !torch.list + %1699 = torch.aten.view %1697, %1698 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1699, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_1432 = torch.constant.none + %1700 = torch.aten.clone %77, %none_1432 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1701 = torch.aten.detach %1700 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1702 = torch.aten.detach %1701 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1703 = torch.aten.detach %1702 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_1433 = torch.constant.int 1 %int1_1434 = torch.constant.int 1 - %int32_1435 = torch.constant.int 32 - %int128_1436 = torch.constant.int 128 - %1440 = torch.prim.ListConstruct %int4_1433, %int1_1434, %int32_1435, %int128_1436 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1441 = torch.aten.view %1427, %1440 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_1437 = torch.constant.int 4 + %int1_1435 = torch.constant.int 1 + %1704 = torch.prim.ListConstruct %int1_1433, %int1_1434, %int1_1435 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1705 = torch.aten.view %1703, %1704 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_1436 = torch.constant.int 32 + %1706 = torch.aten.mul.Scalar %1664, %int32_1436 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int5_1437 = torch.constant.int 5 %int1_1438 = torch.constant.int 1 - %int8_1439 = torch.constant.int 8 - %int128_1440 = torch.constant.int 128 - %1442 = torch.prim.ListConstruct %int4_1437, %int1_1438, %int8_1439, %int128_1440 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1443 = torch.aten.view %1433, %1442 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_1441 = torch.constant.int 4 + %1707 = torch.aten.add.Scalar %1706, %int5_1437, %int1_1438 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_1439 = torch.constant.int 2 + %1708 = 
torch.aten.mul.Scalar %1707, %int2_1439 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_1440 = torch.constant.int 1 + %1709 = torch.aten.add.Tensor %1708, %1705, %int1_1440 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_1441 = torch.constant.int 8 + %1710 = torch.aten.mul.Scalar %1709, %int8_1441 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_1442 = torch.constant.int 1 - %int8_1443 = torch.constant.int 8 - %int128_1444 = torch.constant.int 128 - %1444 = torch.prim.ListConstruct %int4_1441, %int1_1442, %int8_1443, %int128_1444 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1445 = torch.aten.view %1439, %1444 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_1445 = torch.constant.int 6 - %1446 = torch.prims.convert_element_type %1441, %int6_1445 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %1447 = torch_c.to_builtin_tensor %1446 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %1448 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %1449 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%1447, %1448) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %1450 = torch_c.from_builtin_tensor %1449 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_1446 = torch.constant.int 5 - %1451 = torch.prims.convert_element_type %1450, %int5_1446 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_1447 = torch.constant.int 6 - %1452 = torch.prims.convert_element_type %1443, %int6_1447 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %1453 = torch_c.to_builtin_tensor %1452 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %1454 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %1455 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%1453, %1454) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %1456 = torch_c.from_builtin_tensor %1455 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_1448 = torch.constant.int 5 - %1457 = torch.prims.convert_element_type %1456, %int5_1448 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_1449 = torch.constant.int 32 - %1458 = torch.aten.floor_divide.Scalar %arg2, %int32_1449 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_1450 = torch.constant.int 1 - %1459 = torch.aten.unsqueeze %1458, %int1_1450 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1451 = torch.constant.int 1 - %false_1452 = torch.constant.bool false - %1460 = torch.aten.gather %arg3, %int1_1451, %1459, %false_1452 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_1453 = torch.constant.int 32 - %1461 = torch.aten.remainder.Scalar %arg2, %int32_1453 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_1454 = torch.constant.int 1 - %1462 = torch.aten.unsqueeze %1461, %int1_1454 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %1711 = torch.aten.add.Tensor %1710, %1670, %int1_1442 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_1443 = 
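// `+` path: same row-index arithmetic, now with the V partition selector; the index_put below
// writes the current V projection into the cache.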
torch.constant.int 32 + %1712 = torch.aten.mul.Scalar %1711, %int32_1443 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_1444 = torch.constant.int 1 + %1713 = torch.aten.add.Tensor %1712, %1667, %int1_1444 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_1445 = torch.constant.int 5 + %1714 = torch.prims.convert_element_type %1639, %int5_1445 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %1715 = torch.prim.ListConstruct %1713 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_1446 = torch.constant.bool false + %1716 = torch.aten.index_put %1699, %1715, %1714, %false_1446 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1716, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_1447 = torch.constant.int 32 + %int2_1448 = torch.constant.int 2 + %int8_1449 = torch.constant.int 8 + %int32_1450 = torch.constant.int 32 + %int128_1451 = torch.constant.int 128 + %1717 = torch.prim.ListConstruct %456, %int32_1447, %int2_1448, %int8_1449, %int32_1450, %int128_1451 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1718 = torch.aten.view %1716, %1717 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1718, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_1452 = torch.constant.int 2097152 + %1719 = torch.prim.ListConstruct %456, %int2097152_1452 : (!torch.int, !torch.int) -> !torch.list + %1720 = torch.aten.view %1718, %1719 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %1720, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_1453 = torch.constant.none + %1721 = torch.aten.clone %78, %none_1453 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1722 = torch.aten.detach %1721 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1723 = torch.aten.detach %1722 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1724 = torch.aten.detach %1723 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_1454 = torch.constant.none + %1725 = torch.aten.clone %79, %none_1454 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1726 = torch.aten.detach %1725 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1727 = torch.aten.detach %1726 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1728 = torch.aten.detach %1727 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %none_1455 = torch.constant.none - %1463 = torch.aten.clone %60, %none_1455 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_1456 = torch.constant.int 0 - %1464 = torch.aten.unsqueeze %1463, %int0_1456 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_1457 = torch.constant.int 4 - %int1_1458 = torch.constant.int 1 - %1465 = torch.prim.ListConstruct %int4_1457, %int1_1458 : (!torch.int, !torch.int) -> !torch.list - %int1_1459 = torch.constant.int 1 - %int1_1460 = torch.constant.int 1 - %1466 = torch.prim.ListConstruct %int1_1459, %int1_1460 : (!torch.int, !torch.int) -> !torch.list - %int4_1461 = torch.constant.int 4 - %int0_1462 = torch.constant.int 0 - %cpu_1463 = torch.constant.device "cpu" - %false_1464 = 
torch.constant.bool false - %1467 = torch.aten.empty_strided %1465, %1466, %int4_1461, %int0_1462, %cpu_1463, %false_1464 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int5_1465 = torch.constant.int 5 - %1468 = torch.aten.fill.Scalar %1467, %int5_1465 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_1466 = torch.constant.int 4 - %int1_1467 = torch.constant.int 1 - %1469 = torch.prim.ListConstruct %int4_1466, %int1_1467 : (!torch.int, !torch.int) -> !torch.list - %1470 = torch.aten.repeat %1464, %1469 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_1468 = torch.constant.int 32 - %1471 = torch.aten.mul.Scalar %1460, %int32_1468 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1469 = torch.constant.int 1 - %1472 = torch.aten.add.Tensor %1471, %1468, %int1_1469 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_1470 = torch.constant.int 2 - %1473 = torch.aten.mul.Scalar %1472, %int2_1470 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1471 = torch.constant.int 1 - %1474 = torch.aten.add.Tensor %1473, %1470, %int1_1471 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_1472 = torch.constant.int 32 - %1475 = torch.aten.mul.Scalar %1474, %int32_1472 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1473 = torch.constant.int 1 - %1476 = torch.aten.add.Tensor %1475, %1462, %int1_1473 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_1474 = torch.constant.int 32 - %int2_1475 = torch.constant.int 2 - %int32_1476 = torch.constant.int 32 - %int8_1477 = torch.constant.int 8 - %int128_1478 = torch.constant.int 128 - %1477 = torch.prim.ListConstruct %437, %int32_1474, %int2_1475, %int32_1476, %int8_1477, %int128_1478 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1478 = torch.aten.view %1314, %1477 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1478, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_1479 = torch.constant.int 32 - %1479 = torch.aten.mul.int %437, %int32_1479 : !torch.int, !torch.int -> !torch.int - %int2_1480 = torch.constant.int 2 - %1480 = torch.aten.mul.int %1479, %int2_1480 : !torch.int, !torch.int -> !torch.int - %int32_1481 = torch.constant.int 32 - %1481 = torch.aten.mul.int %1480, %int32_1481 : !torch.int, !torch.int -> !torch.int - %int8_1482 = torch.constant.int 8 - %int128_1483 = torch.constant.int 128 - %1482 = torch.prim.ListConstruct %1481, %int8_1482, %int128_1483 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1483 = torch.aten.view %1478, %1482 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1483, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %1484 = torch.prim.ListConstruct %1476 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_1484 = torch.constant.bool false - %1485 = torch.aten.index_put %1483, %1484, %1457, %false_1484 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1485, [%357], affine_map<()[s0] 
-> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> + %1729 = torch.aten.clone %80, %none_1455 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1730 = torch.aten.detach %1729 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1731 = torch.aten.detach %1730 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1732 = torch.aten.detach %1731 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_1456 = torch.constant.int 32 + %int2_1457 = torch.constant.int 2 + %int8_1458 = torch.constant.int 8 + %int32_1459 = torch.constant.int 32 + %int128_1460 = torch.constant.int 128 + %1733 = torch.prim.ListConstruct %456, %int32_1456, %int2_1457, %int8_1458, %int32_1459, %int128_1460 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1734 = torch.aten.view %1720, %1733 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1734, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %1735 = torch_c.to_builtin_tensor %1734 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %1736 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_1461 = tensor.cast %1736 : tensor<4x?xi64> to tensor + %1737 = torch_c.to_builtin_tensor %1724 : !torch.vtensor<[],si64> -> tensor + %1738 = torch_c.to_builtin_tensor %1728 : !torch.vtensor<[],si64> -> tensor + %1739 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%1735, %cast_1461, %1737, %1738) : (tensor, tensor, tensor, tensor) -> tensor + %cast_1462 = tensor.cast %1739 : tensor to tensor<4x?x8x32x128xf16> + %1740 = torch_c.from_builtin_tensor %cast_1462 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %1740, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %1741 = torch_c.to_builtin_tensor %1734 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %1742 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_1463 = tensor.cast %1742 : tensor<4x?xi64> to tensor + %1743 = torch_c.to_builtin_tensor %1724 : !torch.vtensor<[],si64> -> tensor + %1744 = torch_c.to_builtin_tensor %1732 : !torch.vtensor<[],si64> -> tensor + %1745 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%1741, %cast_1463, %1743, %1744) : (tensor, tensor, tensor, tensor) -> tensor + %cast_1464 = tensor.cast %1745 : tensor to tensor<4x?x8x32x128xf16> + %1746 = torch_c.from_builtin_tensor %cast_1464 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %1746, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_1465 = torch.constant.int 2 + %int3_1466 = torch.constant.int 3 + %1747 = torch.aten.transpose.int %1740, %int2_1465, %int3_1466 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1747, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_1467 = torch.constant.int 0 + %1748 = torch.aten.clone %1747, %int0_1467 : 
!torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1748, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_1468 = torch.constant.int 4 + %int8_1469 = torch.constant.int 8 + %int128_1470 = torch.constant.int 128 + %1749 = torch.prim.ListConstruct %int4_1468, %457, %int8_1469, %int128_1470 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1750 = torch.aten._unsafe_view %1748, %1749 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1750, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_1471 = torch.constant.int 2 + %int3_1472 = torch.constant.int 3 + %1751 = torch.aten.transpose.int %1746, %int2_1471, %int3_1472 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1751, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_1473 = torch.constant.int 0 + %1752 = torch.aten.clone %1751, %int0_1473 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1752, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_1474 = torch.constant.int 4 + %int8_1475 = torch.constant.int 8 + %int128_1476 = torch.constant.int 128 + %1753 = torch.prim.ListConstruct %int4_1474, %457, %int8_1475, %int128_1476 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1754 = torch.aten._unsafe_view %1752, %1753 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1754, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_1477 = torch.constant.int -2 + %1755 = torch.aten.unsqueeze %1750, %int-2_1477 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %1755, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_1478 = torch.constant.int 4 + %int8_1479 = torch.constant.int 8 + %int4_1480 = torch.constant.int 4 + %int128_1481 = torch.constant.int 128 + %1756 = torch.prim.ListConstruct %int4_1478, %457, %int8_1479, %int4_1480, %int128_1481 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_1482 = torch.constant.bool false + %1757 = torch.aten.expand %1755, %1756, %false_1482 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1757, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_1483 = torch.constant.int 0 + %1758 = torch.aten.clone %1757, %int0_1483 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1758, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_1484 = torch.constant.int 4 %int32_1485 = torch.constant.int 32 - %int2_1486 = torch.constant.int 2 - %int32_1487 = torch.constant.int 32 - %int8_1488 = torch.constant.int 8 - %int128_1489 = torch.constant.int 128 - %1486 = torch.prim.ListConstruct %437, %int32_1485, %int2_1486, %int32_1487, %int8_1488, %int128_1489 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1487 = torch.aten.view %1485, %1486 : 
!torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1487, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_1490 = torch.constant.int 2097152 - %1488 = torch.prim.ListConstruct %437, %int2097152_1490 : (!torch.int, !torch.int) -> !torch.list - %1489 = torch.aten.view %1487, %1488 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1489, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_1491 = torch.constant.int 32 - %int2_1492 = torch.constant.int 2 - %int32_1493 = torch.constant.int 32 - %int8_1494 = torch.constant.int 8 - %int128_1495 = torch.constant.int 128 - %1490 = torch.prim.ListConstruct %437, %int32_1491, %int2_1492, %int32_1493, %int8_1494, %int128_1495 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1491 = torch.aten.view %1489, %1490 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1491, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_1496 = torch.constant.int 8 - %int128_1497 = torch.constant.int 128 - %1492 = torch.prim.ListConstruct %1481, %int8_1496, %int128_1497 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1493 = torch.aten.view %1491, %1492 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1493, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_1498 = torch.constant.int 32 - %1494 = torch.aten.floor_divide.Scalar %arg2, %int32_1498 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int128_1486 = torch.constant.int 128 + %1759 = torch.prim.ListConstruct %int4_1484, %457, %int32_1485, %int128_1486 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1760 = torch.aten._unsafe_view %1758, %1759 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1760, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_1487 = torch.constant.int -2 + %1761 = torch.aten.unsqueeze %1754, %int-2_1487 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %1761, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_1488 = torch.constant.int 4 + %int8_1489 = torch.constant.int 8 + %int4_1490 = torch.constant.int 4 + %int128_1491 = torch.constant.int 128 + %1762 = torch.prim.ListConstruct %int4_1488, %457, %int8_1489, %int4_1490, %int128_1491 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_1492 = torch.constant.bool false + %1763 = torch.aten.expand %1761, %1762, %false_1492 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1763, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_1493 = torch.constant.int 0 + %1764 = torch.aten.clone %1763, %int0_1493 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1764, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_1494 = torch.constant.int 4 + %int32_1495 
= torch.constant.int 32 + %int128_1496 = torch.constant.int 128 + %1765 = torch.prim.ListConstruct %int4_1494, %457, %int32_1495, %int128_1496 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1766 = torch.aten._unsafe_view %1764, %1765 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1766, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_1497 = torch.constant.int 1 + %int2_1498 = torch.constant.int 2 + %1767 = torch.aten.transpose.int %1649, %int1_1497, %int2_1498 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_1499 = torch.constant.int 1 - %1495 = torch.aten.unsqueeze %1494, %int1_1499 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1500 = torch.constant.int 1 - %false_1501 = torch.constant.bool false - %1496 = torch.aten.gather %arg3, %int1_1500, %1495, %false_1501 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_1502 = torch.constant.int 32 - %1497 = torch.aten.remainder.Scalar %arg2, %int32_1502 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_1503 = torch.constant.int 1 - %1498 = torch.aten.unsqueeze %1497, %int1_1503 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_1504 = torch.constant.none - %1499 = torch.aten.clone %61, %none_1504 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_1505 = torch.constant.int 0 - %1500 = torch.aten.unsqueeze %1499, %int0_1505 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_1506 = torch.constant.int 4 - %int1_1507 = torch.constant.int 1 - %1501 = torch.prim.ListConstruct %int4_1506, %int1_1507 : (!torch.int, !torch.int) -> !torch.list - %int1_1508 = torch.constant.int 1 + %int2_1500 = torch.constant.int 2 + %1768 = torch.aten.transpose.int %1760, %int1_1499, %int2_1500 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1768, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_1501 = torch.constant.int 1 + %int2_1502 = torch.constant.int 2 + %1769 = torch.aten.transpose.int %1766, %int1_1501, %int2_1502 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1769, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_1503 = torch.constant.float 0.000000e+00 + %false_1504 = torch.constant.bool false + %none_1505 = torch.constant.none + %1770:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1767, %1768, %1769, %float0.000000e00_1503, %false_1504, %470, %none_1505) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_1506 = torch.constant.int 1 + %int2_1507 = torch.constant.int 2 + %1771 = torch.aten.transpose.int %1770#0, %int1_1506, %int2_1507 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_1508 = torch.constant.int 4 %int1_1509 = torch.constant.int 1 - %1502 = torch.prim.ListConstruct %int1_1508, %int1_1509 : (!torch.int, !torch.int) -> !torch.list - %int4_1510 = 
torch.constant.int 4 - %int0_1511 = torch.constant.int 0 - %cpu_1512 = torch.constant.device "cpu" - %false_1513 = torch.constant.bool false - %1503 = torch.aten.empty_strided %1501, %1502, %int4_1510, %int0_1511, %cpu_1512, %false_1513 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int5_1514 = torch.constant.int 5 - %1504 = torch.aten.fill.Scalar %1503, %int5_1514 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_1515 = torch.constant.int 4 - %int1_1516 = torch.constant.int 1 - %1505 = torch.prim.ListConstruct %int4_1515, %int1_1516 : (!torch.int, !torch.int) -> !torch.list - %1506 = torch.aten.repeat %1500, %1505 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_1517 = torch.constant.int 32 - %1507 = torch.aten.mul.Scalar %1496, %int32_1517 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1518 = torch.constant.int 1 - %1508 = torch.aten.add.Tensor %1507, %1504, %int1_1518 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_1519 = torch.constant.int 2 - %1509 = torch.aten.mul.Scalar %1508, %int2_1519 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1520 = torch.constant.int 1 - %1510 = torch.aten.add.Tensor %1509, %1506, %int1_1520 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_1521 = torch.constant.int 32 - %1511 = torch.aten.mul.Scalar %1510, %int32_1521 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1522 = torch.constant.int 1 - %1512 = torch.aten.add.Tensor %1511, %1498, %int1_1522 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %1513 = torch.prim.ListConstruct %1512 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_1523 = torch.constant.bool false - %1514 = torch.aten.index_put %1493, %1513, %1445, %false_1523 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1514, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_1524 = torch.constant.int 32 - %int2_1525 = torch.constant.int 2 - %int32_1526 = torch.constant.int 32 - %int8_1527 = torch.constant.int 8 - %int128_1528 = torch.constant.int 128 - %1515 = torch.prim.ListConstruct %437, %int32_1524, %int2_1525, %int32_1526, %int8_1527, %int128_1528 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1516 = torch.aten.view %1514, %1515 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1516, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_1529 = torch.constant.int 2097152 - %1517 = torch.prim.ListConstruct %437, %int2097152_1529 : (!torch.int, !torch.int) -> !torch.list - %1518 = torch.aten.view %1516, %1517 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1518, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_1530 = torch.constant.int 4 - %1519 = torch.prim.ListConstruct %int4_1530, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_1531 = torch.constant.int 1 - %1520 = torch.prim.ListConstruct %358, %int1_1531 : (!torch.int, !torch.int) -> 
!torch.list + %int4096_1510 = torch.constant.int 4096 + %1772 = torch.prim.ListConstruct %int4_1508, %int1_1509, %int4096_1510 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1773 = torch.aten.view %1771, %1772 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_1511 = torch.constant.int -2 + %int-1_1512 = torch.constant.int -1 + %1774 = torch.aten.transpose.int %81, %int-2_1511, %int-1_1512 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_1513 = torch.constant.int 5 + %1775 = torch.prims.convert_element_type %1774, %int5_1513 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_1514 = torch.constant.int 4 + %int4096_1515 = torch.constant.int 4096 + %1776 = torch.prim.ListConstruct %int4_1514, %int4096_1515 : (!torch.int, !torch.int) -> !torch.list + %1777 = torch.aten.view %1773, %1776 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1778 = torch.aten.mm %1777, %1775 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_1516 = torch.constant.int 4 + %int1_1517 = torch.constant.int 1 + %int4096_1518 = torch.constant.int 4096 + %1779 = torch.prim.ListConstruct %int4_1516, %int1_1517, %int4096_1518 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1780 = torch.aten.view %1778, %1779 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_1519 = torch.constant.int 1 + %1781 = torch.aten.add.Tensor %1602, %1780, %int1_1519 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_1520 = torch.constant.int 6 + %1782 = torch.prims.convert_element_type %1781, %int6_1520 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_1521 = torch.constant.int 2 + %1783 = torch.aten.pow.Tensor_Scalar %1782, %int2_1521 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_1522 = torch.constant.int -1 + %1784 = torch.prim.ListConstruct %int-1_1522 : (!torch.int) -> !torch.list + %true_1523 = torch.constant.bool true + %none_1524 = torch.constant.none + %1785 = torch.aten.mean.dim %1783, %1784, %true_1523, %none_1524 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_1525 = torch.constant.float 9.9999997473787516E-6 + %int1_1526 = torch.constant.int 1 + %1786 = torch.aten.add.Scalar %1785, %float9.999990e-06_1525, %int1_1526 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %1787 = torch.aten.rsqrt %1786 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %1788 = torch.aten.mul.Tensor %1782, %1787 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_1527 = torch.constant.int 5 + %1789 = torch.prims.convert_element_type %1788, %int5_1527 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %1790 = torch.aten.mul.Tensor %82, %1789 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_1528 = torch.constant.int 5 + %1791 = torch.prims.convert_element_type %1790, %int5_1528 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_1529 = torch.constant.int -2 + %int-1_1530 = torch.constant.int -1 + %1792 = torch.aten.transpose.int %83, %int-2_1529, %int-1_1530 : 
!torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_1531 = torch.constant.int 5 + %1793 = torch.prims.convert_element_type %1792, %int5_1531 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> %int4_1532 = torch.constant.int 4 - %int0_1533 = torch.constant.int 0 - %cpu_1534 = torch.constant.device "cpu" - %false_1535 = torch.constant.bool false - %1521 = torch.aten.empty_strided %1519, %1520, %int4_1532, %int0_1533, %cpu_1534, %false_1535 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1521, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int5_1536 = torch.constant.int 5 - %1522 = torch.aten.fill.Scalar %1521, %int5_1536 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1522, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_1537 = torch.constant.int 32 - %1523 = torch.aten.mul.Scalar %arg3, %int32_1537 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1523, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_1538 = torch.constant.int 1 - %1524 = torch.aten.add.Tensor %1523, %1522, %int1_1538 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1524, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_1539 = torch.constant.int 4 - %1525 = torch.aten.mul.int %int4_1539, %358 : !torch.int, !torch.int -> !torch.int - %1526 = torch.prim.ListConstruct %1525 : (!torch.int) -> !torch.list - %1527 = torch.aten.view %1524, %1526 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %1527, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_1540 = torch.constant.int 32 - %int2_1541 = torch.constant.int 2 - %int32_1542 = torch.constant.int 32 - %int8_1543 = torch.constant.int 8 - %int128_1544 = torch.constant.int 128 - %1528 = torch.prim.ListConstruct %437, %int32_1540, %int2_1541, %int32_1542, %int8_1543, %int128_1544 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1529 = torch.aten.view %1518, %1528 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1529, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_1545 = torch.constant.int 32 - %1530 = torch.aten.mul.int %437, %int32_1545 : !torch.int, !torch.int -> !torch.int - %int2_1546 = torch.constant.int 2 - %int32_1547 = torch.constant.int 32 - %int8_1548 = torch.constant.int 8 - %int128_1549 = torch.constant.int 128 - %1531 = torch.prim.ListConstruct %1530, %int2_1546, %int32_1547, %int8_1548, %int128_1549 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1532 = torch.aten.view %1529, %1531 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %1532, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_1550 = torch.constant.int 0 - %1533 = torch.aten.index_select %1532, %int0_1550, %1527 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %1533, [%356], 
affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_1551 = torch.constant.int 4 - %int2_1552 = torch.constant.int 2 - %int32_1553 = torch.constant.int 32 - %int8_1554 = torch.constant.int 8 - %int128_1555 = torch.constant.int 128 - %1534 = torch.prim.ListConstruct %int4_1551, %358, %int2_1552, %int32_1553, %int8_1554, %int128_1555 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1535 = torch.aten.view %1533, %1534 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1535, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_1556 = torch.constant.int 0 - %int0_1557 = torch.constant.int 0 - %int9223372036854775807_1558 = torch.constant.int 9223372036854775807 - %int1_1559 = torch.constant.int 1 - %1536 = torch.aten.slice.Tensor %1535, %int0_1556, %int0_1557, %int9223372036854775807_1558, %int1_1559 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1536, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> + %int4096_1533 = torch.constant.int 4096 + %1794 = torch.prim.ListConstruct %int4_1532, %int4096_1533 : (!torch.int, !torch.int) -> !torch.list + %1795 = torch.aten.view %1791, %1794 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1796 = torch.aten.mm %1795, %1793 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_1534 = torch.constant.int 4 + %int1_1535 = torch.constant.int 1 + %int14336_1536 = torch.constant.int 14336 + %1797 = torch.prim.ListConstruct %int4_1534, %int1_1535, %int14336_1536 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1798 = torch.aten.view %1796, %1797 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %1799 = torch.aten.silu %1798 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_1537 = torch.constant.int -2 + %int-1_1538 = torch.constant.int -1 + %1800 = torch.aten.transpose.int %84, %int-2_1537, %int-1_1538 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_1539 = torch.constant.int 5 + %1801 = torch.prims.convert_element_type %1800, %int5_1539 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_1540 = torch.constant.int 4 + %int4096_1541 = torch.constant.int 4096 + %1802 = torch.prim.ListConstruct %int4_1540, %int4096_1541 : (!torch.int, !torch.int) -> !torch.list + %1803 = torch.aten.view %1791, %1802 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1804 = torch.aten.mm %1803, %1801 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_1542 = torch.constant.int 4 + %int1_1543 = torch.constant.int 1 + %int14336_1544 = torch.constant.int 14336 + %1805 = torch.prim.ListConstruct %int4_1542, %int1_1543, %int14336_1544 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1806 = torch.aten.view %1804, %1805 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %1807 = torch.aten.mul.Tensor %1799, %1806 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_1545 = torch.constant.int -2 + %int-1_1546 = torch.constant.int -1 
+ %1808 = torch.aten.transpose.int %85, %int-2_1545, %int-1_1546 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_1547 = torch.constant.int 5 + %1809 = torch.prims.convert_element_type %1808, %int5_1547 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_1548 = torch.constant.int 4 + %int14336_1549 = torch.constant.int 14336 + %1810 = torch.prim.ListConstruct %int4_1548, %int14336_1549 : (!torch.int, !torch.int) -> !torch.list + %1811 = torch.aten.view %1807, %1810 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %1812 = torch.aten.mm %1811, %1809 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_1550 = torch.constant.int 4 + %int1_1551 = torch.constant.int 1 + %int4096_1552 = torch.constant.int 4096 + %1813 = torch.prim.ListConstruct %int4_1550, %int1_1551, %int4096_1552 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1814 = torch.aten.view %1812, %1813 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_1553 = torch.constant.int 1 + %1815 = torch.aten.add.Tensor %1781, %1814, %int1_1553 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_1554 = torch.constant.int 6 + %1816 = torch.prims.convert_element_type %1815, %int6_1554 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_1555 = torch.constant.int 2 + %1817 = torch.aten.pow.Tensor_Scalar %1816, %int2_1555 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_1556 = torch.constant.int -1 + %1818 = torch.prim.ListConstruct %int-1_1556 : (!torch.int) -> !torch.list + %true_1557 = torch.constant.bool true + %none_1558 = torch.constant.none + %1819 = torch.aten.mean.dim %1817, %1818, %true_1557, %none_1558 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_1559 = torch.constant.float 9.9999997473787516E-6 %int1_1560 = torch.constant.int 1 - %int0_1561 = torch.constant.int 0 - %int9223372036854775807_1562 = torch.constant.int 9223372036854775807 - %int1_1563 = torch.constant.int 1 - %1537 = torch.aten.slice.Tensor %1536, %int1_1560, %int0_1561, %int9223372036854775807_1562, %int1_1563 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1537, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_1564 = torch.constant.int 2 - %int0_1565 = torch.constant.int 0 - %1538 = torch.aten.select.int %1537, %int2_1564, %int0_1565 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1538, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_1566 = torch.constant.int 32 - %1539 = torch.aten.mul.int %358, %int32_1566 : !torch.int, !torch.int -> !torch.int - %int2_1567 = torch.constant.int 2 - %int0_1568 = torch.constant.int 0 + %1820 = torch.aten.add.Scalar %1819, %float9.999990e-06_1559, %int1_1560 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %1821 = torch.aten.rsqrt %1820 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %1822 = torch.aten.mul.Tensor %1816, %1821 : !torch.vtensor<[4,1,4096],f32>, 
!torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_1561 = torch.constant.int 5 + %1823 = torch.prims.convert_element_type %1822, %int5_1561 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %1824 = torch.aten.mul.Tensor %86, %1823 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_1562 = torch.constant.int 5 + %1825 = torch.prims.convert_element_type %1824, %int5_1562 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_1563 = torch.constant.int -2 + %int-1_1564 = torch.constant.int -1 + %1826 = torch.aten.transpose.int %87, %int-2_1563, %int-1_1564 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_1565 = torch.constant.int 5 + %1827 = torch.prims.convert_element_type %1826, %int5_1565 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_1566 = torch.constant.int 4 + %int4096_1567 = torch.constant.int 4096 + %1828 = torch.prim.ListConstruct %int4_1566, %int4096_1567 : (!torch.int, !torch.int) -> !torch.list + %1829 = torch.aten.view %1825, %1828 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1830 = torch.aten.mm %1829, %1827 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_1568 = torch.constant.int 4 %int1_1569 = torch.constant.int 1 - %1540 = torch.aten.slice.Tensor %1538, %int2_1567, %int0_1568, %1539, %int1_1569 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1540, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_1570 = torch.constant.int 0 - %1541 = torch.aten.clone %1540, %int0_1570 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1541, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_1571 = torch.constant.int 1 - %1542 = torch.aten.size.int %1537, %int1_1571 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_1572 = torch.constant.int 32 - %1543 = torch.aten.mul.int %1542, %int32_1572 : !torch.int, !torch.int -> !torch.int - %int4_1573 = torch.constant.int 4 - %int8_1574 = torch.constant.int 8 - %int128_1575 = torch.constant.int 128 - %1544 = torch.prim.ListConstruct %int4_1573, %1543, %int8_1574, %int128_1575 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1545 = torch.aten._unsafe_view %1541, %1544 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1545, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_1576 = torch.constant.int 0 - %int0_1577 = torch.constant.int 0 - %int9223372036854775807_1578 = torch.constant.int 9223372036854775807 - %int1_1579 = torch.constant.int 1 - %1546 = torch.aten.slice.Tensor %1545, %int0_1576, %int0_1577, %int9223372036854775807_1578, %int1_1579 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1546, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_1580 = torch.constant.int 0 - %int0_1581 = torch.constant.int 0 - %int9223372036854775807_1582 = torch.constant.int 9223372036854775807 - %int1_1583 = 
torch.constant.int 1 - %1547 = torch.aten.slice.Tensor %1535, %int0_1580, %int0_1581, %int9223372036854775807_1582, %int1_1583 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1547, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_1584 = torch.constant.int 1 - %int0_1585 = torch.constant.int 0 - %int9223372036854775807_1586 = torch.constant.int 9223372036854775807 - %int1_1587 = torch.constant.int 1 - %1548 = torch.aten.slice.Tensor %1547, %int1_1584, %int0_1585, %int9223372036854775807_1586, %int1_1587 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1548, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_1588 = torch.constant.int 2 - %int1_1589 = torch.constant.int 1 - %1549 = torch.aten.select.int %1548, %int2_1588, %int1_1589 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1549, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_1590 = torch.constant.int 2 - %int0_1591 = torch.constant.int 0 + %int4096_1570 = torch.constant.int 4096 + %1831 = torch.prim.ListConstruct %int4_1568, %int1_1569, %int4096_1570 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1832 = torch.aten.view %1830, %1831 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_1571 = torch.constant.int -2 + %int-1_1572 = torch.constant.int -1 + %1833 = torch.aten.transpose.int %88, %int-2_1571, %int-1_1572 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_1573 = torch.constant.int 5 + %1834 = torch.prims.convert_element_type %1833, %int5_1573 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_1574 = torch.constant.int 4 + %int4096_1575 = torch.constant.int 4096 + %1835 = torch.prim.ListConstruct %int4_1574, %int4096_1575 : (!torch.int, !torch.int) -> !torch.list + %1836 = torch.aten.view %1825, %1835 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1837 = torch.aten.mm %1836, %1834 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_1576 = torch.constant.int 4 + %int1_1577 = torch.constant.int 1 + %int1024_1578 = torch.constant.int 1024 + %1838 = torch.prim.ListConstruct %int4_1576, %int1_1577, %int1024_1578 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1839 = torch.aten.view %1837, %1838 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_1579 = torch.constant.int -2 + %int-1_1580 = torch.constant.int -1 + %1840 = torch.aten.transpose.int %89, %int-2_1579, %int-1_1580 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_1581 = torch.constant.int 5 + %1841 = torch.prims.convert_element_type %1840, %int5_1581 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_1582 = torch.constant.int 4 + %int4096_1583 = torch.constant.int 4096 + %1842 = torch.prim.ListConstruct %int4_1582, %int4096_1583 : (!torch.int, !torch.int) -> !torch.list + %1843 = torch.aten.view %1825, %1842 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1844 
= torch.aten.mm %1843, %1841 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_1584 = torch.constant.int 4 + %int1_1585 = torch.constant.int 1 + %int1024_1586 = torch.constant.int 1024 + %1845 = torch.prim.ListConstruct %int4_1584, %int1_1585, %int1024_1586 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1846 = torch.aten.view %1844, %1845 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_1587 = torch.constant.int 4 + %int1_1588 = torch.constant.int 1 + %int32_1589 = torch.constant.int 32 + %int128_1590 = torch.constant.int 128 + %1847 = torch.prim.ListConstruct %int4_1587, %int1_1588, %int32_1589, %int128_1590 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1848 = torch.aten.view %1832, %1847 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_1591 = torch.constant.int 4 %int1_1592 = torch.constant.int 1 - %1550 = torch.aten.slice.Tensor %1549, %int2_1590, %int0_1591, %1539, %int1_1592 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1550, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_1593 = torch.constant.int 0 - %1551 = torch.aten.clone %1550, %int0_1593 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1551, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_1594 = torch.constant.int 1 - %1552 = torch.aten.size.int %1548, %int1_1594 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_1595 = torch.constant.int 32 - %1553 = torch.aten.mul.int %1552, %int32_1595 : !torch.int, !torch.int -> !torch.int - %int4_1596 = torch.constant.int 4 + %int8_1593 = torch.constant.int 8 + %int128_1594 = torch.constant.int 128 + %1849 = torch.prim.ListConstruct %int4_1591, %int1_1592, %int8_1593, %int128_1594 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1850 = torch.aten.view %1839, %1849 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_1595 = torch.constant.int 4 + %int1_1596 = torch.constant.int 1 %int8_1597 = torch.constant.int 8 %int128_1598 = torch.constant.int 128 - %1554 = torch.prim.ListConstruct %int4_1596, %1553, %int8_1597, %int128_1598 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1555 = torch.aten._unsafe_view %1551, %1554 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1555, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_1599 = torch.constant.int 0 - %int0_1600 = torch.constant.int 0 - %int9223372036854775807_1601 = torch.constant.int 9223372036854775807 - %int1_1602 = torch.constant.int 1 - %1556 = torch.aten.slice.Tensor %1555, %int0_1599, %int0_1600, %int9223372036854775807_1601, %int1_1602 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1556, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_1603 = torch.constant.int -2 - %1557 = torch.aten.unsqueeze %1546, %int-2_1603 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1557, [%356], affine_map<()[s0] -> (4, s0 
* 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %1851 = torch.prim.ListConstruct %int4_1595, %int1_1596, %int8_1597, %int128_1598 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1852 = torch.aten.view %1846, %1851 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_1599 = torch.constant.int 1 + %int2_1600 = torch.constant.int 2 + %1853 = torch.aten.transpose.int %1848, %int1_1599, %int2_1600 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %1854 = torch.aten.mul.Tensor %1853, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_1601 = torch.constant.int 3 + %int0_1602 = torch.constant.int 0 + %int64_1603 = torch.constant.int 64 %int1_1604 = torch.constant.int 1 - %1558 = torch.aten.size.int %1545, %int1_1604 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_1605 = torch.constant.int 4 - %int8_1606 = torch.constant.int 8 - %int4_1607 = torch.constant.int 4 - %int128_1608 = torch.constant.int 128 - %1559 = torch.prim.ListConstruct %int4_1605, %1558, %int8_1606, %int4_1607, %int128_1608 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1609 = torch.constant.bool false - %1560 = torch.aten.expand %1557, %1559, %false_1609 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1560, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1610 = torch.constant.int 0 - %1561 = torch.aten.clone %1560, %int0_1610 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1561, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_1611 = torch.constant.int 4 - %int32_1612 = torch.constant.int 32 - %int128_1613 = torch.constant.int 128 - %1562 = torch.prim.ListConstruct %int4_1611, %1558, %int32_1612, %int128_1613 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1563 = torch.aten._unsafe_view %1561, %1562 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1563, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_1614 = torch.constant.int -2 - %1564 = torch.aten.unsqueeze %1556, %int-2_1614 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1564, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_1615 = torch.constant.int 1 - %1565 = torch.aten.size.int %1555, %int1_1615 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_1616 = torch.constant.int 4 - %int8_1617 = torch.constant.int 8 - %int4_1618 = torch.constant.int 4 - %int128_1619 = torch.constant.int 128 - %1566 = torch.prim.ListConstruct %int4_1616, %1565, %int8_1617, %int4_1618, %int128_1619 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1620 = torch.constant.bool false - %1567 = torch.aten.expand %1564, %1566, %false_1620 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1567, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1621 = torch.constant.int 0 - %1568 = torch.aten.clone %1567, 
%int0_1621 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1568, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_1622 = torch.constant.int 4 - %int32_1623 = torch.constant.int 32 - %int128_1624 = torch.constant.int 128 - %1569 = torch.prim.ListConstruct %int4_1622, %1565, %int32_1623, %int128_1624 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1570 = torch.aten._unsafe_view %1568, %1569 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1570, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %1855 = torch.aten.slice.Tensor %1853, %int3_1601, %int0_1602, %int64_1603, %int1_1604 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_1605 = torch.constant.int 3 + %int64_1606 = torch.constant.int 64 + %int9223372036854775807_1607 = torch.constant.int 9223372036854775807 + %int1_1608 = torch.constant.int 1 + %1856 = torch.aten.slice.Tensor %1853, %int3_1605, %int64_1606, %int9223372036854775807_1607, %int1_1608 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %1857 = torch.aten.neg %1856 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %1858 = torch.prim.ListConstruct %1857, %1855 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_1609 = torch.constant.int -1 + %1859 = torch.aten.cat %1858, %int-1_1609 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %1860 = torch.aten.mul.Tensor %1859, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_1610 = torch.constant.int 1 + %1861 = torch.aten.add.Tensor %1854, %1860, %int1_1610 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_1611 = torch.constant.int 1 + %int2_1612 = torch.constant.int 2 + %1862 = torch.aten.transpose.int %1861, %int1_1611, %int2_1612 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_1613 = torch.constant.int 1 + %int2_1614 = torch.constant.int 2 + %1863 = torch.aten.transpose.int %1850, %int1_1613, %int2_1614 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %1864 = torch.aten.mul.Tensor %1863, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_1615 = torch.constant.int 3 + %int0_1616 = torch.constant.int 0 + %int64_1617 = torch.constant.int 64 + %int1_1618 = torch.constant.int 1 + %1865 = torch.aten.slice.Tensor %1863, %int3_1615, %int0_1616, %int64_1617, %int1_1618 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_1619 = torch.constant.int 3 + %int64_1620 = torch.constant.int 64 + %int9223372036854775807_1621 = torch.constant.int 9223372036854775807 + %int1_1622 = torch.constant.int 1 + %1866 = torch.aten.slice.Tensor %1863, %int3_1619, %int64_1620, %int9223372036854775807_1621, %int1_1622 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %1867 = torch.aten.neg %1866 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %1868 = 
torch.prim.ListConstruct %1867, %1865 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_1623 = torch.constant.int -1 + %1869 = torch.aten.cat %1868, %int-1_1623 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %1870 = torch.aten.mul.Tensor %1869, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_1624 = torch.constant.int 1 + %1871 = torch.aten.add.Tensor %1864, %1870, %int1_1624 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> %int1_1625 = torch.constant.int 1 %int2_1626 = torch.constant.int 2 - %1571 = torch.aten.transpose.int %1451, %int1_1625, %int2_1626 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_1627 = torch.constant.int 1 - %int2_1628 = torch.constant.int 2 - %1572 = torch.aten.transpose.int %1563, %int1_1627, %int2_1628 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1572, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %1872 = torch.aten.transpose.int %1871, %int1_1625, %int2_1626 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_1627 = torch.constant.int 32 + %1873 = torch.aten.floor_divide.Scalar %arg2, %int32_1627 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_1628 = torch.constant.int 1 + %1874 = torch.aten.unsqueeze %1873, %int1_1628 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> %int1_1629 = torch.constant.int 1 - %int2_1630 = torch.constant.int 2 - %1573 = torch.aten.transpose.int %1570, %int1_1629, %int2_1630 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1573, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_1631 = torch.constant.float 0.000000e+00 - %false_1632 = torch.constant.bool false - %none_1633 = torch.constant.none - %1574:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1571, %1572, %1573, %float0.000000e00_1631, %false_1632, %368, %none_1633) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_1634 = torch.constant.int 1 - %int2_1635 = torch.constant.int 2 - %1575 = torch.aten.transpose.int %1574#0, %int1_1634, %int2_1635 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_1636 = torch.constant.int 4 + %false_1630 = torch.constant.bool false + %1875 = torch.aten.gather %arg3, %int1_1629, %1874, %false_1630 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_1631 = torch.constant.int 4 + %int1_1632 = torch.constant.int 1 + %int1_1633 = torch.constant.int 1 + %1876 = torch.prim.ListConstruct %int4_1631, %int1_1632, %int1_1633 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1877 = torch.aten.view %1875, %1876 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_1634 = torch.constant.int 32 + %1878 = torch.aten.remainder.Scalar %arg2, %int32_1634 : !torch.vtensor<[4],si64>, !torch.int -> 
!torch.vtensor<[4],si64> + %int4_1635 = torch.constant.int 4 + %int1_1636 = torch.constant.int 1 %int1_1637 = torch.constant.int 1 - %int4096_1638 = torch.constant.int 4096 - %1576 = torch.prim.ListConstruct %int4_1636, %int1_1637, %int4096_1638 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1577 = torch.aten.view %1575, %1576 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_1639 = torch.constant.int -2 - %int-1_1640 = torch.constant.int -1 - %1578 = torch.aten.transpose.int %62, %int-2_1639, %int-1_1640 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_1641 = torch.constant.int 4 - %int4096_1642 = torch.constant.int 4096 - %1579 = torch.prim.ListConstruct %int4_1641, %int4096_1642 : (!torch.int, !torch.int) -> !torch.list - %1580 = torch.aten.view %1577, %1579 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1581 = torch.aten.mm %1580, %1578 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_1643 = torch.constant.int 4 + %1879 = torch.prim.ListConstruct %int4_1635, %int1_1636, %int1_1637 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1880 = torch.aten.view %1878, %1879 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_1638 = torch.constant.int 8 + %none_1639 = torch.constant.none + %none_1640 = torch.constant.none + %cpu_1641 = torch.constant.device "cpu" + %false_1642 = torch.constant.bool false + %1881 = torch.aten.arange %int8_1638, %none_1639, %none_1640, %cpu_1641, %false_1642 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_1643 = torch.constant.int 1 %int1_1644 = torch.constant.int 1 - %int4096_1645 = torch.constant.int 4096 - %1582 = torch.prim.ListConstruct %int4_1643, %int1_1644, %int4096_1645 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1583 = torch.aten.view %1581, %1582 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_1646 = torch.constant.int 1 - %1584 = torch.aten.add.Tensor %1411, %1583, %int1_1646 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_1647 = torch.constant.int 6 - %1585 = torch.prims.convert_element_type %1584, %int6_1647 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_1648 = torch.constant.int 2 - %1586 = torch.aten.pow.Tensor_Scalar %1585, %int2_1648 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_1649 = torch.constant.int -1 - %1587 = torch.prim.ListConstruct %int-1_1649 : (!torch.int) -> !torch.list - %true_1650 = torch.constant.bool true - %none_1651 = torch.constant.none - %1588 = torch.aten.mean.dim %1586, %1587, %true_1650, %none_1651 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_1652 = torch.constant.float 9.9999997473787516E-6 - %int1_1653 = torch.constant.int 1 - %1589 = torch.aten.add.Scalar %1588, %float9.999990e-06_1652, %int1_1653 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %1590 = torch.aten.rsqrt %1589 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %1591 = torch.aten.mul.Tensor %1585, %1590 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_1654 = torch.constant.int 5 - %1592 = 
torch.prims.convert_element_type %1591, %int5_1654 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %1593 = torch.aten.mul.Tensor %63, %1592 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_1655 = torch.constant.int 5 - %1594 = torch.prims.convert_element_type %1593, %int5_1655 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_1656 = torch.constant.int -2 - %int-1_1657 = torch.constant.int -1 - %1595 = torch.aten.transpose.int %64, %int-2_1656, %int-1_1657 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_1658 = torch.constant.int 4 - %int4096_1659 = torch.constant.int 4096 - %1596 = torch.prim.ListConstruct %int4_1658, %int4096_1659 : (!torch.int, !torch.int) -> !torch.list - %1597 = torch.aten.view %1594, %1596 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1598 = torch.aten.mm %1597, %1595 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_1660 = torch.constant.int 4 - %int1_1661 = torch.constant.int 1 - %int14336_1662 = torch.constant.int 14336 - %1599 = torch.prim.ListConstruct %int4_1660, %int1_1661, %int14336_1662 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1600 = torch.aten.view %1598, %1599 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %1601 = torch.aten.silu %1600 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_1663 = torch.constant.int -2 - %int-1_1664 = torch.constant.int -1 - %1602 = torch.aten.transpose.int %65, %int-2_1663, %int-1_1664 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_1665 = torch.constant.int 4 - %int4096_1666 = torch.constant.int 4096 - %1603 = torch.prim.ListConstruct %int4_1665, %int4096_1666 : (!torch.int, !torch.int) -> !torch.list - %1604 = torch.aten.view %1594, %1603 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1605 = torch.aten.mm %1604, %1602 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_1667 = torch.constant.int 4 - %int1_1668 = torch.constant.int 1 - %int14336_1669 = torch.constant.int 14336 - %1606 = torch.prim.ListConstruct %int4_1667, %int1_1668, %int14336_1669 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1607 = torch.aten.view %1605, %1606 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %1608 = torch.aten.mul.Tensor %1601, %1607 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_1670 = torch.constant.int -2 - %int-1_1671 = torch.constant.int -1 - %1609 = torch.aten.transpose.int %66, %int-2_1670, %int-1_1671 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_1672 = torch.constant.int 4 - %int14336_1673 = torch.constant.int 14336 - %1610 = torch.prim.ListConstruct %int4_1672, %int14336_1673 : (!torch.int, !torch.int) -> !torch.list - %1611 = torch.aten.view %1608, %1610 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %1612 = torch.aten.mm %1611, %1609 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_1674 = torch.constant.int 4 - %int1_1675 = torch.constant.int 1 - %int4096_1676 = torch.constant.int 4096 - %1613 = 
torch.prim.ListConstruct %int4_1674, %int1_1675, %int4096_1676 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1614 = torch.aten.view %1612, %1613 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_1677 = torch.constant.int 1 - %1615 = torch.aten.add.Tensor %1584, %1614, %int1_1677 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_1678 = torch.constant.int 6 - %1616 = torch.prims.convert_element_type %1615, %int6_1678 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_1679 = torch.constant.int 2 - %1617 = torch.aten.pow.Tensor_Scalar %1616, %int2_1679 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_1680 = torch.constant.int -1 - %1618 = torch.prim.ListConstruct %int-1_1680 : (!torch.int) -> !torch.list - %true_1681 = torch.constant.bool true - %none_1682 = torch.constant.none - %1619 = torch.aten.mean.dim %1617, %1618, %true_1681, %none_1682 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_1683 = torch.constant.float 9.9999997473787516E-6 - %int1_1684 = torch.constant.int 1 - %1620 = torch.aten.add.Scalar %1619, %float9.999990e-06_1683, %int1_1684 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %1621 = torch.aten.rsqrt %1620 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %1622 = torch.aten.mul.Tensor %1616, %1621 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_1685 = torch.constant.int 5 - %1623 = torch.prims.convert_element_type %1622, %int5_1685 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %1624 = torch.aten.mul.Tensor %67, %1623 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_1686 = torch.constant.int 5 - %1625 = torch.prims.convert_element_type %1624, %int5_1686 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_1687 = torch.constant.int -2 - %int-1_1688 = torch.constant.int -1 - %1626 = torch.aten.transpose.int %68, %int-2_1687, %int-1_1688 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_1689 = torch.constant.int 4 - %int4096_1690 = torch.constant.int 4096 - %1627 = torch.prim.ListConstruct %int4_1689, %int4096_1690 : (!torch.int, !torch.int) -> !torch.list - %1628 = torch.aten.view %1625, %1627 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1629 = torch.aten.mm %1628, %1626 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_1691 = torch.constant.int 4 - %int1_1692 = torch.constant.int 1 - %int4096_1693 = torch.constant.int 4096 - %1630 = torch.prim.ListConstruct %int4_1691, %int1_1692, %int4096_1693 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1631 = torch.aten.view %1629, %1630 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_1694 = torch.constant.int -2 - %int-1_1695 = torch.constant.int -1 - %1632 = torch.aten.transpose.int %69, %int-2_1694, %int-1_1695 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_1696 = torch.constant.int 4 - %int4096_1697 = torch.constant.int 4096 - %1633 = torch.prim.ListConstruct %int4_1696, %int4096_1697 : (!torch.int, 
!torch.int) -> !torch.list - %1634 = torch.aten.view %1625, %1633 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1635 = torch.aten.mm %1634, %1632 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_1698 = torch.constant.int 4 - %int1_1699 = torch.constant.int 1 - %int1024_1700 = torch.constant.int 1024 - %1636 = torch.prim.ListConstruct %int4_1698, %int1_1699, %int1024_1700 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1637 = torch.aten.view %1635, %1636 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_1701 = torch.constant.int -2 - %int-1_1702 = torch.constant.int -1 - %1638 = torch.aten.transpose.int %70, %int-2_1701, %int-1_1702 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_1703 = torch.constant.int 4 - %int4096_1704 = torch.constant.int 4096 - %1639 = torch.prim.ListConstruct %int4_1703, %int4096_1704 : (!torch.int, !torch.int) -> !torch.list - %1640 = torch.aten.view %1625, %1639 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1641 = torch.aten.mm %1640, %1638 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_1705 = torch.constant.int 4 - %int1_1706 = torch.constant.int 1 - %int1024_1707 = torch.constant.int 1024 - %1642 = torch.prim.ListConstruct %int4_1705, %int1_1706, %int1024_1707 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1643 = torch.aten.view %1641, %1642 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_1708 = torch.constant.int 4 - %int1_1709 = torch.constant.int 1 - %int32_1710 = torch.constant.int 32 - %int128_1711 = torch.constant.int 128 - %1644 = torch.prim.ListConstruct %int4_1708, %int1_1709, %int32_1710, %int128_1711 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1645 = torch.aten.view %1631, %1644 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_1712 = torch.constant.int 4 - %int1_1713 = torch.constant.int 1 - %int8_1714 = torch.constant.int 8 - %int128_1715 = torch.constant.int 128 - %1646 = torch.prim.ListConstruct %int4_1712, %int1_1713, %int8_1714, %int128_1715 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1647 = torch.aten.view %1637, %1646 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_1716 = torch.constant.int 4 - %int1_1717 = torch.constant.int 1 - %int8_1718 = torch.constant.int 8 - %int128_1719 = torch.constant.int 128 - %1648 = torch.prim.ListConstruct %int4_1716, %int1_1717, %int8_1718, %int128_1719 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1649 = torch.aten.view %1643, %1648 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_1720 = torch.constant.int 6 - %1650 = torch.prims.convert_element_type %1645, %int6_1720 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %1651 = torch_c.to_builtin_tensor %1650 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %1652 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %1653 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%1651, %1652) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %1654 = torch_c.from_builtin_tensor %1653 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - 
%int5_1721 = torch.constant.int 5 - %1655 = torch.prims.convert_element_type %1654, %int5_1721 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_1722 = torch.constant.int 6 - %1656 = torch.prims.convert_element_type %1647, %int6_1722 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %1657 = torch_c.to_builtin_tensor %1656 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %1658 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %1659 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%1657, %1658) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %1660 = torch_c.from_builtin_tensor %1659 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_1723 = torch.constant.int 5 - %1661 = torch.prims.convert_element_type %1660, %int5_1723 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_1724 = torch.constant.int 32 - %1662 = torch.aten.floor_divide.Scalar %arg2, %int32_1724 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_1725 = torch.constant.int 1 - %1663 = torch.aten.unsqueeze %1662, %int1_1725 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1726 = torch.constant.int 1 - %false_1727 = torch.constant.bool false - %1664 = torch.aten.gather %arg3, %int1_1726, %1663, %false_1727 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_1728 = torch.constant.int 32 - %1665 = torch.aten.remainder.Scalar %arg2, %int32_1728 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_1729 = torch.constant.int 1 - %1666 = torch.aten.unsqueeze %1665, %int1_1729 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_1730 = torch.constant.none - %1667 = torch.aten.clone %71, %none_1730 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_1731 = torch.constant.int 0 - %1668 = torch.aten.unsqueeze %1667, %int0_1731 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_1732 = torch.constant.int 4 - %int1_1733 = torch.constant.int 1 - %1669 = torch.prim.ListConstruct %int4_1732, %int1_1733 : (!torch.int, !torch.int) -> !torch.list - %int1_1734 = torch.constant.int 1 - %int1_1735 = torch.constant.int 1 - %1670 = torch.prim.ListConstruct %int1_1734, %int1_1735 : (!torch.int, !torch.int) -> !torch.list - %int4_1736 = torch.constant.int 4 - %int0_1737 = torch.constant.int 0 - %cpu_1738 = torch.constant.device "cpu" + %int8_1645 = torch.constant.int 8 + %1882 = torch.prim.ListConstruct %int1_1643, %int1_1644, %int8_1645 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1883 = torch.aten.view %1881, %1882 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_1646 = torch.constant.none + %1884 = torch.aten.clone %90, %none_1646 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1885 = torch.aten.detach %1884 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1886 = torch.aten.detach %1885 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1887 = torch.aten.detach %1886 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_1647 = torch.constant.int 1 + %int1_1648 = torch.constant.int 1 + %int1_1649 = torch.constant.int 1 + %1888 = torch.prim.ListConstruct %int1_1647, %int1_1648, %int1_1649 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1889 = 
torch.aten.view %1887, %1888 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64>
+ %int32_1650 = torch.constant.int 32
+ %1890 = torch.aten.mul.Scalar %1877, %int32_1650 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int6_1651 = torch.constant.int 6
+ %int1_1652 = torch.constant.int 1
+ %1891 = torch.aten.add.Scalar %1890, %int6_1651, %int1_1652 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int2_1653 = torch.constant.int 2
+ %1892 = torch.aten.mul.Scalar %1891, %int2_1653 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int1_1654 = torch.constant.int 1
+ %1893 = torch.aten.add.Tensor %1892, %1889, %int1_1654 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int8_1655 = torch.constant.int 8
+ %1894 = torch.aten.mul.Scalar %1893, %int8_1655 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int1_1656 = torch.constant.int 1
+ %1895 = torch.aten.add.Tensor %1894, %1883, %int1_1656 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int32_1657 = torch.constant.int 32
+ %1896 = torch.aten.mul.Scalar %1895, %int32_1657 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int1_1658 = torch.constant.int 1
+ %1897 = torch.aten.add.Tensor %1896, %1880, %int1_1658 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int5_1659 = torch.constant.int 5
+ %1898 = torch.prims.convert_element_type %1872, %int5_1659 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+ %int32_1660 = torch.constant.int 32
+ %int2_1661 = torch.constant.int 2
+ %int8_1662 = torch.constant.int 8
+ %int32_1663 = torch.constant.int 32
+ %int128_1664 = torch.constant.int 128
+ %1899 = torch.prim.ListConstruct %456, %int32_1660, %int2_1661, %int8_1662, %int32_1663, %int128_1664 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1900 = torch.aten.view %1720, %1899 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1900, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int128_1665 = torch.constant.int 128
+ %1901 = torch.prim.ListConstruct %596, %int128_1665 : (!torch.int, !torch.int) -> !torch.list<int>
+ %1902 = torch.aten.view %1900, %1901 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %1902, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %1903 = torch.prim.ListConstruct %1897 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+ %false_1666 = torch.constant.bool false
+ %1904 = torch.aten.index_put %1902, %1903, %1898, %false_1666 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %1904, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %int32_1667 = torch.constant.int 32
+ %int2_1668 = torch.constant.int 2
+ %int8_1669 = torch.constant.int 8
+ %int32_1670 = torch.constant.int 32
+ %int128_1671 = torch.constant.int 128
+ %1905 = torch.prim.ListConstruct %456, %int32_1667, %int2_1668, %int8_1669, %int32_1670, %int128_1671 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) ->
!torch.list + %1906 = torch.aten.view %1904, %1905 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1906, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_1672 = torch.constant.int 2097152 + %1907 = torch.prim.ListConstruct %456, %int2097152_1672 : (!torch.int, !torch.int) -> !torch.list + %1908 = torch.aten.view %1906, %1907 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %1908, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_1673 = torch.constant.int 32 + %int2_1674 = torch.constant.int 2 + %int8_1675 = torch.constant.int 8 + %int32_1676 = torch.constant.int 32 + %int128_1677 = torch.constant.int 128 + %1909 = torch.prim.ListConstruct %456, %int32_1673, %int2_1674, %int8_1675, %int32_1676, %int128_1677 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1910 = torch.aten.view %1908, %1909 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1910, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_1678 = torch.constant.int 128 + %1911 = torch.prim.ListConstruct %596, %int128_1678 : (!torch.int, !torch.int) -> !torch.list + %1912 = torch.aten.view %1910, %1911 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1912, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_1679 = torch.constant.none + %1913 = torch.aten.clone %91, %none_1679 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1914 = torch.aten.detach %1913 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1915 = torch.aten.detach %1914 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1916 = torch.aten.detach %1915 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_1680 = torch.constant.int 1 + %int1_1681 = torch.constant.int 1 + %int1_1682 = torch.constant.int 1 + %1917 = torch.prim.ListConstruct %int1_1680, %int1_1681, %int1_1682 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1918 = torch.aten.view %1916, %1917 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_1683 = torch.constant.int 32 + %1919 = torch.aten.mul.Scalar %1877, %int32_1683 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int6_1684 = torch.constant.int 6 + %int1_1685 = torch.constant.int 1 + %1920 = torch.aten.add.Scalar %1919, %int6_1684, %int1_1685 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_1686 = torch.constant.int 2 + %1921 = torch.aten.mul.Scalar %1920, %int2_1686 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_1687 = torch.constant.int 1 + %1922 = torch.aten.add.Tensor %1921, %1918, %int1_1687 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_1688 = torch.constant.int 8 + %1923 = torch.aten.mul.Scalar %1922, %int8_1688 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_1689 = torch.constant.int 1 + %1924 = torch.aten.add.Tensor %1923, %1883, %int1_1689 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_1690 = torch.constant.int 32 + 
%1925 = torch.aten.mul.Scalar %1924, %int32_1690 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_1691 = torch.constant.int 1 + %1926 = torch.aten.add.Tensor %1925, %1880, %int1_1691 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_1692 = torch.constant.int 5 + %1927 = torch.prims.convert_element_type %1852, %int5_1692 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %1928 = torch.prim.ListConstruct %1926 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_1693 = torch.constant.bool false + %1929 = torch.aten.index_put %1912, %1928, %1927, %false_1693 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %1929, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_1694 = torch.constant.int 32 + %int2_1695 = torch.constant.int 2 + %int8_1696 = torch.constant.int 8 + %int32_1697 = torch.constant.int 32 + %int128_1698 = torch.constant.int 128 + %1930 = torch.prim.ListConstruct %456, %int32_1694, %int2_1695, %int8_1696, %int32_1697, %int128_1698 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1931 = torch.aten.view %1929, %1930 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %1931, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_1699 = torch.constant.int 2097152 + %1932 = torch.prim.ListConstruct %456, %int2097152_1699 : (!torch.int, !torch.int) -> !torch.list + %1933 = torch.aten.view %1931, %1932 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %1933, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_1700 = torch.constant.none + %1934 = torch.aten.clone %92, %none_1700 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1935 = torch.aten.detach %1934 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1936 = torch.aten.detach %1935 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1937 = torch.aten.detach %1936 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_1701 = torch.constant.none + %1938 = torch.aten.clone %93, %none_1701 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1939 = torch.aten.detach %1938 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1940 = torch.aten.detach %1939 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1941 = torch.aten.detach %1940 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_1702 = torch.constant.none + %1942 = torch.aten.clone %94, %none_1702 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %1943 = torch.aten.detach %1942 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1944 = torch.aten.detach %1943 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %1945 = torch.aten.detach %1944 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_1703 = torch.constant.int 32 + %int2_1704 = torch.constant.int 2 + %int8_1705 = torch.constant.int 8 + %int32_1706 = torch.constant.int 32 + %int128_1707 = torch.constant.int 128 + %1946 = torch.prim.ListConstruct %456, %int32_1703, %int2_1704, %int8_1705, %int32_1706, %int128_1707 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + 
%1947 = torch.aten.view %1933, %1946 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %1947, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %1948 = torch_c.to_builtin_tensor %1947 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+ %1949 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+ %cast_1708 = tensor.cast %1949 : tensor<4x?xi64> to tensor<?x?xi64>
+ %1950 = torch_c.to_builtin_tensor %1937 : !torch.vtensor<[],si64> -> tensor<i64>
+ %1951 = torch_c.to_builtin_tensor %1941 : !torch.vtensor<[],si64> -> tensor<i64>
+ %1952 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%1948, %cast_1708, %1950, %1951) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+ %cast_1709 = tensor.cast %1952 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+ %1953 = torch_c.from_builtin_tensor %cast_1709 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+ torch.bind_symbolic_shape %1953, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+ %1954 = torch_c.to_builtin_tensor %1947 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+ %1955 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+ %cast_1710 = tensor.cast %1955 : tensor<4x?xi64> to tensor<?x?xi64>
+ %1956 = torch_c.to_builtin_tensor %1937 : !torch.vtensor<[],si64> -> tensor<i64>
+ %1957 = torch_c.to_builtin_tensor %1945 : !torch.vtensor<[],si64> -> tensor<i64>
+ %1958 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%1954, %cast_1710, %1956, %1957) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+ %cast_1711 = tensor.cast %1958 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+ %1959 = torch_c.from_builtin_tensor %cast_1711 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+ torch.bind_symbolic_shape %1959, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+ %int2_1712 = torch.constant.int 2
+ %int3_1713 = torch.constant.int 3
+ %1960 = torch.aten.transpose.int %1953, %int2_1712, %int3_1713 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %1960, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int0_1714 = torch.constant.int 0
+ %1961 = torch.aten.clone %1960, %int0_1714 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %1961, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int4_1715 = torch.constant.int 4
+ %int8_1716 = torch.constant.int 8
+ %int128_1717 = torch.constant.int 128
+ %1962 = torch.prim.ListConstruct %int4_1715, %457, %int8_1716, %int128_1717 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1963 = torch.aten._unsafe_view %1961, %1962 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %1963, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int2_1718 = torch.constant.int 2
+ %int3_1719 = torch.constant.int 3
+ %1964 = torch.aten.transpose.int
%1959, %int2_1718, %int3_1719 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1964, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_1720 = torch.constant.int 0 + %1965 = torch.aten.clone %1964, %int0_1720 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %1965, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_1721 = torch.constant.int 4 + %int8_1722 = torch.constant.int 8 + %int128_1723 = torch.constant.int 128 + %1966 = torch.prim.ListConstruct %int4_1721, %457, %int8_1722, %int128_1723 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1967 = torch.aten._unsafe_view %1965, %1966 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %1967, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_1724 = torch.constant.int -2 + %1968 = torch.aten.unsqueeze %1963, %int-2_1724 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %1968, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_1725 = torch.constant.int 4 + %int8_1726 = torch.constant.int 8 + %int4_1727 = torch.constant.int 4 + %int128_1728 = torch.constant.int 128 + %1969 = torch.prim.ListConstruct %int4_1725, %457, %int8_1726, %int4_1727, %int128_1728 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_1729 = torch.constant.bool false + %1970 = torch.aten.expand %1968, %1969, %false_1729 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1970, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_1730 = torch.constant.int 0 + %1971 = torch.aten.clone %1970, %int0_1730 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1971, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_1731 = torch.constant.int 4 + %int32_1732 = torch.constant.int 32 + %int128_1733 = torch.constant.int 128 + %1972 = torch.prim.ListConstruct %int4_1731, %457, %int32_1732, %int128_1733 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1973 = torch.aten._unsafe_view %1971, %1972 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1973, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_1734 = torch.constant.int -2 + %1974 = torch.aten.unsqueeze %1967, %int-2_1734 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %1974, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_1735 = torch.constant.int 4 + %int8_1736 = torch.constant.int 8 + %int4_1737 = torch.constant.int 4 + %int128_1738 = torch.constant.int 128 + %1975 = torch.prim.ListConstruct %int4_1735, %457, %int8_1736, %int4_1737, %int128_1738 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %false_1739 = torch.constant.bool false - %1671 = torch.aten.empty_strided %1669, %1670, %int4_1736, %int0_1737, %cpu_1738, %false_1739 : !torch.list, 
!torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int6_1740 = torch.constant.int 6 - %1672 = torch.aten.fill.Scalar %1671, %int6_1740 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %1976 = torch.aten.expand %1974, %1975, %false_1739 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1976, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_1740 = torch.constant.int 0 + %1977 = torch.aten.clone %1976, %int0_1740 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %1977, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_1741 = torch.constant.int 4 - %int1_1742 = torch.constant.int 1 - %1673 = torch.prim.ListConstruct %int4_1741, %int1_1742 : (!torch.int, !torch.int) -> !torch.list - %1674 = torch.aten.repeat %1668, %1673 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_1743 = torch.constant.int 32 - %1675 = torch.aten.mul.Scalar %1664, %int32_1743 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int32_1742 = torch.constant.int 32 + %int128_1743 = torch.constant.int 128 + %1978 = torch.prim.ListConstruct %int4_1741, %457, %int32_1742, %int128_1743 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1979 = torch.aten._unsafe_view %1977, %1978 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %1979, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int1_1744 = torch.constant.int 1 - %1676 = torch.aten.add.Tensor %1675, %1672, %int1_1744 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> %int2_1745 = torch.constant.int 2 - %1677 = torch.aten.mul.Scalar %1676, %int2_1745 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %1980 = torch.aten.transpose.int %1862, %int1_1744, %int2_1745 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_1746 = torch.constant.int 1 - %1678 = torch.aten.add.Tensor %1677, %1674, %int1_1746 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_1747 = torch.constant.int 32 - %1679 = torch.aten.mul.Scalar %1678, %int32_1747 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int2_1747 = torch.constant.int 2 + %1981 = torch.aten.transpose.int %1973, %int1_1746, %int2_1747 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1981, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_1748 = torch.constant.int 1 - %1680 = torch.aten.add.Tensor %1679, %1666, %int1_1748 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_1749 = torch.constant.int 32 - %int2_1750 = torch.constant.int 2 - %int32_1751 = torch.constant.int 32 - %int8_1752 = torch.constant.int 8 - %int128_1753 = torch.constant.int 128 - %1681 = torch.prim.ListConstruct %437, %int32_1749, %int2_1750, %int32_1751, %int8_1752, %int128_1753 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1682 = torch.aten.view %1518, %1681 : 
!torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1682, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_1754 = torch.constant.int 32 - %1683 = torch.aten.mul.int %437, %int32_1754 : !torch.int, !torch.int -> !torch.int - %int2_1755 = torch.constant.int 2 - %1684 = torch.aten.mul.int %1683, %int2_1755 : !torch.int, !torch.int -> !torch.int - %int32_1756 = torch.constant.int 32 - %1685 = torch.aten.mul.int %1684, %int32_1756 : !torch.int, !torch.int -> !torch.int - %int8_1757 = torch.constant.int 8 - %int128_1758 = torch.constant.int 128 - %1686 = torch.prim.ListConstruct %1685, %int8_1757, %int128_1758 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1687 = torch.aten.view %1682, %1686 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1687, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %1688 = torch.prim.ListConstruct %1680 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_1759 = torch.constant.bool false - %1689 = torch.aten.index_put %1687, %1688, %1661, %false_1759 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1689, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_1760 = torch.constant.int 32 - %int2_1761 = torch.constant.int 2 - %int32_1762 = torch.constant.int 32 - %int8_1763 = torch.constant.int 8 - %int128_1764 = torch.constant.int 128 - %1690 = torch.prim.ListConstruct %437, %int32_1760, %int2_1761, %int32_1762, %int8_1763, %int128_1764 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1691 = torch.aten.view %1689, %1690 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1691, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_1765 = torch.constant.int 2097152 - %1692 = torch.prim.ListConstruct %437, %int2097152_1765 : (!torch.int, !torch.int) -> !torch.list - %1693 = torch.aten.view %1691, %1692 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1693, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_1766 = torch.constant.int 32 - %int2_1767 = torch.constant.int 2 - %int32_1768 = torch.constant.int 32 - %int8_1769 = torch.constant.int 8 - %int128_1770 = torch.constant.int 128 - %1694 = torch.prim.ListConstruct %437, %int32_1766, %int2_1767, %int32_1768, %int8_1769, %int128_1770 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1695 = torch.aten.view %1693, %1694 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1695, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_1771 = torch.constant.int 8 - %int128_1772 = torch.constant.int 128 - %1696 = torch.prim.ListConstruct %1685, %int8_1771, %int128_1772 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1697 = torch.aten.view %1695, %1696 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1697, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : 
!torch.vtensor<[?,8,128],f16> - %int32_1773 = torch.constant.int 32 - %1698 = torch.aten.floor_divide.Scalar %arg2, %int32_1773 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_1774 = torch.constant.int 1 - %1699 = torch.aten.unsqueeze %1698, %int1_1774 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1775 = torch.constant.int 1 - %false_1776 = torch.constant.bool false - %1700 = torch.aten.gather %arg3, %int1_1775, %1699, %false_1776 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_1777 = torch.constant.int 32 - %1701 = torch.aten.remainder.Scalar %arg2, %int32_1777 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_1778 = torch.constant.int 1 - %1702 = torch.aten.unsqueeze %1701, %int1_1778 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_1779 = torch.constant.none - %1703 = torch.aten.clone %72, %none_1779 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_1780 = torch.constant.int 0 - %1704 = torch.aten.unsqueeze %1703, %int0_1780 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + %int2_1749 = torch.constant.int 2 + %1982 = torch.aten.transpose.int %1979, %int1_1748, %int2_1749 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %1982, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_1750 = torch.constant.float 0.000000e+00 + %false_1751 = torch.constant.bool false + %none_1752 = torch.constant.none + %1983:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1980, %1981, %1982, %float0.000000e00_1750, %false_1751, %470, %none_1752) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_1753 = torch.constant.int 1 + %int2_1754 = torch.constant.int 2 + %1984 = torch.aten.transpose.int %1983#0, %int1_1753, %int2_1754 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_1755 = torch.constant.int 4 + %int1_1756 = torch.constant.int 1 + %int4096_1757 = torch.constant.int 4096 + %1985 = torch.prim.ListConstruct %int4_1755, %int1_1756, %int4096_1757 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1986 = torch.aten.view %1984, %1985 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_1758 = torch.constant.int -2 + %int-1_1759 = torch.constant.int -1 + %1987 = torch.aten.transpose.int %95, %int-2_1758, %int-1_1759 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_1760 = torch.constant.int 5 + %1988 = torch.prims.convert_element_type %1987, %int5_1760 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_1761 = torch.constant.int 4 + %int4096_1762 = torch.constant.int 4096 + %1989 = torch.prim.ListConstruct %int4_1761, %int4096_1762 : (!torch.int, !torch.int) -> !torch.list + %1990 = torch.aten.view %1986, %1989 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %1991 = torch.aten.mm %1990, %1988 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_1763 = torch.constant.int 4 + 
%int1_1764 = torch.constant.int 1 + %int4096_1765 = torch.constant.int 4096 + %1992 = torch.prim.ListConstruct %int4_1763, %int1_1764, %int4096_1765 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1993 = torch.aten.view %1991, %1992 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_1766 = torch.constant.int 1 + %1994 = torch.aten.add.Tensor %1815, %1993, %int1_1766 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_1767 = torch.constant.int 6 + %1995 = torch.prims.convert_element_type %1994, %int6_1767 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_1768 = torch.constant.int 2 + %1996 = torch.aten.pow.Tensor_Scalar %1995, %int2_1768 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_1769 = torch.constant.int -1 + %1997 = torch.prim.ListConstruct %int-1_1769 : (!torch.int) -> !torch.list + %true_1770 = torch.constant.bool true + %none_1771 = torch.constant.none + %1998 = torch.aten.mean.dim %1996, %1997, %true_1770, %none_1771 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_1772 = torch.constant.float 9.9999997473787516E-6 + %int1_1773 = torch.constant.int 1 + %1999 = torch.aten.add.Scalar %1998, %float9.999990e-06_1772, %int1_1773 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %2000 = torch.aten.rsqrt %1999 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %2001 = torch.aten.mul.Tensor %1995, %2000 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_1774 = torch.constant.int 5 + %2002 = torch.prims.convert_element_type %2001, %int5_1774 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %2003 = torch.aten.mul.Tensor %96, %2002 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_1775 = torch.constant.int 5 + %2004 = torch.prims.convert_element_type %2003, %int5_1775 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_1776 = torch.constant.int -2 + %int-1_1777 = torch.constant.int -1 + %2005 = torch.aten.transpose.int %97, %int-2_1776, %int-1_1777 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_1778 = torch.constant.int 5 + %2006 = torch.prims.convert_element_type %2005, %int5_1778 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_1779 = torch.constant.int 4 + %int4096_1780 = torch.constant.int 4096 + %2007 = torch.prim.ListConstruct %int4_1779, %int4096_1780 : (!torch.int, !torch.int) -> !torch.list + %2008 = torch.aten.view %2004, %2007 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2009 = torch.aten.mm %2008, %2006 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> %int4_1781 = torch.constant.int 4 %int1_1782 = torch.constant.int 1 - %1705 = torch.prim.ListConstruct %int4_1781, %int1_1782 : (!torch.int, !torch.int) -> !torch.list - %int1_1783 = torch.constant.int 1 - %int1_1784 = torch.constant.int 1 - %1706 = torch.prim.ListConstruct %int1_1783, %int1_1784 : (!torch.int, !torch.int) -> !torch.list - %int4_1785 = torch.constant.int 4 - %int0_1786 = torch.constant.int 0 - %cpu_1787 = torch.constant.device "cpu" - %false_1788 = 
torch.constant.bool false - %1707 = torch.aten.empty_strided %1705, %1706, %int4_1785, %int0_1786, %cpu_1787, %false_1788 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int6_1789 = torch.constant.int 6 - %1708 = torch.aten.fill.Scalar %1707, %int6_1789 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_1790 = torch.constant.int 4 - %int1_1791 = torch.constant.int 1 - %1709 = torch.prim.ListConstruct %int4_1790, %int1_1791 : (!torch.int, !torch.int) -> !torch.list - %1710 = torch.aten.repeat %1704, %1709 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_1792 = torch.constant.int 32 - %1711 = torch.aten.mul.Scalar %1700, %int32_1792 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1793 = torch.constant.int 1 - %1712 = torch.aten.add.Tensor %1711, %1708, %int1_1793 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_1794 = torch.constant.int 2 - %1713 = torch.aten.mul.Scalar %1712, %int2_1794 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1795 = torch.constant.int 1 - %1714 = torch.aten.add.Tensor %1713, %1710, %int1_1795 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_1796 = torch.constant.int 32 - %1715 = torch.aten.mul.Scalar %1714, %int32_1796 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_1797 = torch.constant.int 1 - %1716 = torch.aten.add.Tensor %1715, %1702, %int1_1797 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %1717 = torch.prim.ListConstruct %1716 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_1798 = torch.constant.bool false - %1718 = torch.aten.index_put %1697, %1717, %1649, %false_1798 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1718, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_1799 = torch.constant.int 32 - %int2_1800 = torch.constant.int 2 - %int32_1801 = torch.constant.int 32 - %int8_1802 = torch.constant.int 8 - %int128_1803 = torch.constant.int 128 - %1719 = torch.prim.ListConstruct %437, %int32_1799, %int2_1800, %int32_1801, %int8_1802, %int128_1803 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1720 = torch.aten.view %1718, %1719 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1720, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_1804 = torch.constant.int 2097152 - %1721 = torch.prim.ListConstruct %437, %int2097152_1804 : (!torch.int, !torch.int) -> !torch.list - %1722 = torch.aten.view %1720, %1721 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1722, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_1805 = torch.constant.int 4 - %1723 = torch.prim.ListConstruct %int4_1805, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_1806 = torch.constant.int 1 - %1724 = torch.prim.ListConstruct %358, %int1_1806 : (!torch.int, !torch.int) -> !torch.list - %int4_1807 = torch.constant.int 4 - %int0_1808 = torch.constant.int 0 - %cpu_1809 = torch.constant.device 
"cpu" - %false_1810 = torch.constant.bool false - %1725 = torch.aten.empty_strided %1723, %1724, %int4_1807, %int0_1808, %cpu_1809, %false_1810 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1725, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int6_1811 = torch.constant.int 6 - %1726 = torch.aten.fill.Scalar %1725, %int6_1811 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1726, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_1812 = torch.constant.int 32 - %1727 = torch.aten.mul.Scalar %arg3, %int32_1812 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1727, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_1813 = torch.constant.int 1 - %1728 = torch.aten.add.Tensor %1727, %1726, %int1_1813 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1728, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_1814 = torch.constant.int 4 - %1729 = torch.aten.mul.int %int4_1814, %358 : !torch.int, !torch.int -> !torch.int - %1730 = torch.prim.ListConstruct %1729 : (!torch.int) -> !torch.list - %1731 = torch.aten.view %1728, %1730 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %1731, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_1815 = torch.constant.int 32 - %int2_1816 = torch.constant.int 2 - %int32_1817 = torch.constant.int 32 - %int8_1818 = torch.constant.int 8 - %int128_1819 = torch.constant.int 128 - %1732 = torch.prim.ListConstruct %437, %int32_1815, %int2_1816, %int32_1817, %int8_1818, %int128_1819 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1733 = torch.aten.view %1722, %1732 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1733, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_1820 = torch.constant.int 32 - %1734 = torch.aten.mul.int %437, %int32_1820 : !torch.int, !torch.int -> !torch.int - %int2_1821 = torch.constant.int 2 - %int32_1822 = torch.constant.int 32 - %int8_1823 = torch.constant.int 8 - %int128_1824 = torch.constant.int 128 - %1735 = torch.prim.ListConstruct %1734, %int2_1821, %int32_1822, %int8_1823, %int128_1824 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1736 = torch.aten.view %1733, %1735 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %1736, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_1825 = torch.constant.int 0 - %1737 = torch.aten.index_select %1736, %int0_1825, %1731 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %1737, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_1826 = torch.constant.int 4 - %int2_1827 = torch.constant.int 2 - %int32_1828 = torch.constant.int 32 - %int8_1829 = torch.constant.int 8 - %int128_1830 = torch.constant.int 128 - %1738 = torch.prim.ListConstruct %int4_1826, %358, %int2_1827, %int32_1828, %int8_1829, %int128_1830 : (!torch.int, 
!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1739 = torch.aten.view %1737, %1738 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1739, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_1831 = torch.constant.int 0 - %int0_1832 = torch.constant.int 0 - %int9223372036854775807_1833 = torch.constant.int 9223372036854775807 - %int1_1834 = torch.constant.int 1 - %1740 = torch.aten.slice.Tensor %1739, %int0_1831, %int0_1832, %int9223372036854775807_1833, %int1_1834 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1740, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> + %int14336_1783 = torch.constant.int 14336 + %2010 = torch.prim.ListConstruct %int4_1781, %int1_1782, %int14336_1783 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2011 = torch.aten.view %2009, %2010 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %2012 = torch.aten.silu %2011 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_1784 = torch.constant.int -2 + %int-1_1785 = torch.constant.int -1 + %2013 = torch.aten.transpose.int %98, %int-2_1784, %int-1_1785 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_1786 = torch.constant.int 5 + %2014 = torch.prims.convert_element_type %2013, %int5_1786 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_1787 = torch.constant.int 4 + %int4096_1788 = torch.constant.int 4096 + %2015 = torch.prim.ListConstruct %int4_1787, %int4096_1788 : (!torch.int, !torch.int) -> !torch.list + %2016 = torch.aten.view %2004, %2015 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2017 = torch.aten.mm %2016, %2014 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_1789 = torch.constant.int 4 + %int1_1790 = torch.constant.int 1 + %int14336_1791 = torch.constant.int 14336 + %2018 = torch.prim.ListConstruct %int4_1789, %int1_1790, %int14336_1791 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2019 = torch.aten.view %2017, %2018 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %2020 = torch.aten.mul.Tensor %2012, %2019 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_1792 = torch.constant.int -2 + %int-1_1793 = torch.constant.int -1 + %2021 = torch.aten.transpose.int %99, %int-2_1792, %int-1_1793 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_1794 = torch.constant.int 5 + %2022 = torch.prims.convert_element_type %2021, %int5_1794 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_1795 = torch.constant.int 4 + %int14336_1796 = torch.constant.int 14336 + %2023 = torch.prim.ListConstruct %int4_1795, %int14336_1796 : (!torch.int, !torch.int) -> !torch.list + %2024 = torch.aten.view %2020, %2023 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %2025 = torch.aten.mm %2024, %2022 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_1797 = torch.constant.int 4 + %int1_1798 = torch.constant.int 1 + 
%int4096_1799 = torch.constant.int 4096 + %2026 = torch.prim.ListConstruct %int4_1797, %int1_1798, %int4096_1799 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2027 = torch.aten.view %2025, %2026 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_1800 = torch.constant.int 1 + %2028 = torch.aten.add.Tensor %1994, %2027, %int1_1800 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_1801 = torch.constant.int 6 + %2029 = torch.prims.convert_element_type %2028, %int6_1801 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_1802 = torch.constant.int 2 + %2030 = torch.aten.pow.Tensor_Scalar %2029, %int2_1802 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_1803 = torch.constant.int -1 + %2031 = torch.prim.ListConstruct %int-1_1803 : (!torch.int) -> !torch.list + %true_1804 = torch.constant.bool true + %none_1805 = torch.constant.none + %2032 = torch.aten.mean.dim %2030, %2031, %true_1804, %none_1805 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_1806 = torch.constant.float 9.9999997473787516E-6 + %int1_1807 = torch.constant.int 1 + %2033 = torch.aten.add.Scalar %2032, %float9.999990e-06_1806, %int1_1807 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %2034 = torch.aten.rsqrt %2033 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %2035 = torch.aten.mul.Tensor %2029, %2034 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_1808 = torch.constant.int 5 + %2036 = torch.prims.convert_element_type %2035, %int5_1808 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %2037 = torch.aten.mul.Tensor %100, %2036 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_1809 = torch.constant.int 5 + %2038 = torch.prims.convert_element_type %2037, %int5_1809 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_1810 = torch.constant.int -2 + %int-1_1811 = torch.constant.int -1 + %2039 = torch.aten.transpose.int %101, %int-2_1810, %int-1_1811 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_1812 = torch.constant.int 5 + %2040 = torch.prims.convert_element_type %2039, %int5_1812 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_1813 = torch.constant.int 4 + %int4096_1814 = torch.constant.int 4096 + %2041 = torch.prim.ListConstruct %int4_1813, %int4096_1814 : (!torch.int, !torch.int) -> !torch.list + %2042 = torch.aten.view %2038, %2041 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2043 = torch.aten.mm %2042, %2040 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_1815 = torch.constant.int 4 + %int1_1816 = torch.constant.int 1 + %int4096_1817 = torch.constant.int 4096 + %2044 = torch.prim.ListConstruct %int4_1815, %int1_1816, %int4096_1817 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2045 = torch.aten.view %2043, %2044 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_1818 = torch.constant.int -2 + %int-1_1819 = torch.constant.int -1 + %2046 = torch.aten.transpose.int %102, %int-2_1818, %int-1_1819 : 
!torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_1820 = torch.constant.int 5 + %2047 = torch.prims.convert_element_type %2046, %int5_1820 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_1821 = torch.constant.int 4 + %int4096_1822 = torch.constant.int 4096 + %2048 = torch.prim.ListConstruct %int4_1821, %int4096_1822 : (!torch.int, !torch.int) -> !torch.list + %2049 = torch.aten.view %2038, %2048 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2050 = torch.aten.mm %2049, %2047 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_1823 = torch.constant.int 4 + %int1_1824 = torch.constant.int 1 + %int1024_1825 = torch.constant.int 1024 + %2051 = torch.prim.ListConstruct %int4_1823, %int1_1824, %int1024_1825 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2052 = torch.aten.view %2050, %2051 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_1826 = torch.constant.int -2 + %int-1_1827 = torch.constant.int -1 + %2053 = torch.aten.transpose.int %103, %int-2_1826, %int-1_1827 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_1828 = torch.constant.int 5 + %2054 = torch.prims.convert_element_type %2053, %int5_1828 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_1829 = torch.constant.int 4 + %int4096_1830 = torch.constant.int 4096 + %2055 = torch.prim.ListConstruct %int4_1829, %int4096_1830 : (!torch.int, !torch.int) -> !torch.list + %2056 = torch.aten.view %2038, %2055 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2057 = torch.aten.mm %2056, %2054 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_1831 = torch.constant.int 4 + %int1_1832 = torch.constant.int 1 + %int1024_1833 = torch.constant.int 1024 + %2058 = torch.prim.ListConstruct %int4_1831, %int1_1832, %int1024_1833 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2059 = torch.aten.view %2057, %2058 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_1834 = torch.constant.int 4 %int1_1835 = torch.constant.int 1 - %int0_1836 = torch.constant.int 0 - %int9223372036854775807_1837 = torch.constant.int 9223372036854775807 - %int1_1838 = torch.constant.int 1 - %1741 = torch.aten.slice.Tensor %1740, %int1_1835, %int0_1836, %int9223372036854775807_1837, %int1_1838 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1741, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_1839 = torch.constant.int 2 - %int0_1840 = torch.constant.int 0 - %1742 = torch.aten.select.int %1741, %int2_1839, %int0_1840 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1742, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_1841 = torch.constant.int 32 - %1743 = torch.aten.mul.int %358, %int32_1841 : !torch.int, !torch.int -> !torch.int - %int2_1842 = torch.constant.int 2 - %int0_1843 = torch.constant.int 0 - %int1_1844 = torch.constant.int 1 - %1744 = torch.aten.slice.Tensor %1742, %int2_1842, %int0_1843, %1743, %int1_1844 : !torch.vtensor<[4,?,32,8,128],f16>, 
!torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1744, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_1845 = torch.constant.int 0 - %1745 = torch.aten.clone %1744, %int0_1845 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1745, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int32_1836 = torch.constant.int 32 + %int128_1837 = torch.constant.int 128 + %2060 = torch.prim.ListConstruct %int4_1834, %int1_1835, %int32_1836, %int128_1837 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2061 = torch.aten.view %2045, %2060 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_1838 = torch.constant.int 4 + %int1_1839 = torch.constant.int 1 + %int8_1840 = torch.constant.int 8 + %int128_1841 = torch.constant.int 128 + %2062 = torch.prim.ListConstruct %int4_1838, %int1_1839, %int8_1840, %int128_1841 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2063 = torch.aten.view %2052, %2062 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_1842 = torch.constant.int 4 + %int1_1843 = torch.constant.int 1 + %int8_1844 = torch.constant.int 8 + %int128_1845 = torch.constant.int 128 + %2064 = torch.prim.ListConstruct %int4_1842, %int1_1843, %int8_1844, %int128_1845 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2065 = torch.aten.view %2059, %2064 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> %int1_1846 = torch.constant.int 1 - %1746 = torch.aten.size.int %1741, %int1_1846 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_1847 = torch.constant.int 32 - %1747 = torch.aten.mul.int %1746, %int32_1847 : !torch.int, !torch.int -> !torch.int - %int4_1848 = torch.constant.int 4 - %int8_1849 = torch.constant.int 8 - %int128_1850 = torch.constant.int 128 - %1748 = torch.prim.ListConstruct %int4_1848, %1747, %int8_1849, %int128_1850 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1749 = torch.aten._unsafe_view %1745, %1748 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1749, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_1851 = torch.constant.int 0 - %int0_1852 = torch.constant.int 0 - %int9223372036854775807_1853 = torch.constant.int 9223372036854775807 - %int1_1854 = torch.constant.int 1 - %1750 = torch.aten.slice.Tensor %1749, %int0_1851, %int0_1852, %int9223372036854775807_1853, %int1_1854 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1750, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_1855 = torch.constant.int 0 - %int0_1856 = torch.constant.int 0 - %int9223372036854775807_1857 = torch.constant.int 9223372036854775807 + %int2_1847 = torch.constant.int 2 + %2066 = torch.aten.transpose.int %2061, %int1_1846, %int2_1847 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %2067 = torch.aten.mul.Tensor %2066, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_1848 = torch.constant.int 3 + %int0_1849 = torch.constant.int 0 + 
%int64_1850 = torch.constant.int 64 + %int1_1851 = torch.constant.int 1 + %2068 = torch.aten.slice.Tensor %2066, %int3_1848, %int0_1849, %int64_1850, %int1_1851 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_1852 = torch.constant.int 3 + %int64_1853 = torch.constant.int 64 + %int9223372036854775807_1854 = torch.constant.int 9223372036854775807 + %int1_1855 = torch.constant.int 1 + %2069 = torch.aten.slice.Tensor %2066, %int3_1852, %int64_1853, %int9223372036854775807_1854, %int1_1855 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %2070 = torch.aten.neg %2069 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %2071 = torch.prim.ListConstruct %2070, %2068 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_1856 = torch.constant.int -1 + %2072 = torch.aten.cat %2071, %int-1_1856 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %2073 = torch.aten.mul.Tensor %2072, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_1857 = torch.constant.int 1 + %2074 = torch.aten.add.Tensor %2067, %2073, %int1_1857 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_1858 = torch.constant.int 1 - %1751 = torch.aten.slice.Tensor %1739, %int0_1855, %int0_1856, %int9223372036854775807_1857, %int1_1858 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1751, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_1859 = torch.constant.int 1 - %int0_1860 = torch.constant.int 0 - %int9223372036854775807_1861 = torch.constant.int 9223372036854775807 - %int1_1862 = torch.constant.int 1 - %1752 = torch.aten.slice.Tensor %1751, %int1_1859, %int0_1860, %int9223372036854775807_1861, %int1_1862 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1752, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_1863 = torch.constant.int 2 - %int1_1864 = torch.constant.int 1 - %1753 = torch.aten.select.int %1752, %int2_1863, %int1_1864 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1753, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_1865 = torch.constant.int 2 - %int0_1866 = torch.constant.int 0 - %int1_1867 = torch.constant.int 1 - %1754 = torch.aten.slice.Tensor %1753, %int2_1865, %int0_1866, %1743, %int1_1867 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1754, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_1868 = torch.constant.int 0 - %1755 = torch.aten.clone %1754, %int0_1868 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1755, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int2_1859 = torch.constant.int 2 + %2075 = torch.aten.transpose.int %2074, %int1_1858, %int2_1859 : 
!torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_1860 = torch.constant.int 1 + %int2_1861 = torch.constant.int 2 + %2076 = torch.aten.transpose.int %2063, %int1_1860, %int2_1861 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %2077 = torch.aten.mul.Tensor %2076, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_1862 = torch.constant.int 3 + %int0_1863 = torch.constant.int 0 + %int64_1864 = torch.constant.int 64 + %int1_1865 = torch.constant.int 1 + %2078 = torch.aten.slice.Tensor %2076, %int3_1862, %int0_1863, %int64_1864, %int1_1865 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_1866 = torch.constant.int 3 + %int64_1867 = torch.constant.int 64 + %int9223372036854775807_1868 = torch.constant.int 9223372036854775807 %int1_1869 = torch.constant.int 1 - %1756 = torch.aten.size.int %1752, %int1_1869 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_1870 = torch.constant.int 32 - %1757 = torch.aten.mul.int %1756, %int32_1870 : !torch.int, !torch.int -> !torch.int - %int4_1871 = torch.constant.int 4 - %int8_1872 = torch.constant.int 8 - %int128_1873 = torch.constant.int 128 - %1758 = torch.prim.ListConstruct %int4_1871, %1757, %int8_1872, %int128_1873 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1759 = torch.aten._unsafe_view %1755, %1758 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1759, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_1874 = torch.constant.int 0 - %int0_1875 = torch.constant.int 0 - %int9223372036854775807_1876 = torch.constant.int 9223372036854775807 - %int1_1877 = torch.constant.int 1 - %1760 = torch.aten.slice.Tensor %1759, %int0_1874, %int0_1875, %int9223372036854775807_1876, %int1_1877 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1760, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_1878 = torch.constant.int -2 - %1761 = torch.aten.unsqueeze %1750, %int-2_1878 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1761, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %2079 = torch.aten.slice.Tensor %2076, %int3_1866, %int64_1867, %int9223372036854775807_1868, %int1_1869 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %2080 = torch.aten.neg %2079 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %2081 = torch.prim.ListConstruct %2080, %2078 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_1870 = torch.constant.int -1 + %2082 = torch.aten.cat %2081, %int-1_1870 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %2083 = torch.aten.mul.Tensor %2082, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_1871 = torch.constant.int 1 + %2084 = torch.aten.add.Tensor %2077, %2083, %int1_1871 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_1872 = torch.constant.int 1 + 
%int2_1873 = torch.constant.int 2 + %2085 = torch.aten.transpose.int %2084, %int1_1872, %int2_1873 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_1874 = torch.constant.int 32 + %2086 = torch.aten.floor_divide.Scalar %arg2, %int32_1874 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_1875 = torch.constant.int 1 + %2087 = torch.aten.unsqueeze %2086, %int1_1875 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_1876 = torch.constant.int 1 + %false_1877 = torch.constant.bool false + %2088 = torch.aten.gather %arg3, %int1_1876, %2087, %false_1877 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_1878 = torch.constant.int 4 %int1_1879 = torch.constant.int 1 - %1762 = torch.aten.size.int %1749, %int1_1879 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_1880 = torch.constant.int 4 - %int8_1881 = torch.constant.int 8 + %int1_1880 = torch.constant.int 1 + %2089 = torch.prim.ListConstruct %int4_1878, %int1_1879, %int1_1880 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2090 = torch.aten.view %2088, %2089 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_1881 = torch.constant.int 32 + %2091 = torch.aten.remainder.Scalar %arg2, %int32_1881 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> %int4_1882 = torch.constant.int 4 - %int128_1883 = torch.constant.int 128 - %1763 = torch.prim.ListConstruct %int4_1880, %1762, %int8_1881, %int4_1882, %int128_1883 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1884 = torch.constant.bool false - %1764 = torch.aten.expand %1761, %1763, %false_1884 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1764, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1885 = torch.constant.int 0 - %1765 = torch.aten.clone %1764, %int0_1885 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1765, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_1886 = torch.constant.int 4 - %int32_1887 = torch.constant.int 32 - %int128_1888 = torch.constant.int 128 - %1766 = torch.prim.ListConstruct %int4_1886, %1762, %int32_1887, %int128_1888 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1767 = torch.aten._unsafe_view %1765, %1766 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1767, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_1889 = torch.constant.int -2 - %1768 = torch.aten.unsqueeze %1760, %int-2_1889 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1768, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int1_1883 = torch.constant.int 1 + %int1_1884 = torch.constant.int 1 + %2092 = torch.prim.ListConstruct %int4_1882, %int1_1883, %int1_1884 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2093 = torch.aten.view %2091, %2092 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_1885 = torch.constant.int 8 + %none_1886 = torch.constant.none + %none_1887 = torch.constant.none + %cpu_1888 = 
torch.constant.device "cpu" + %false_1889 = torch.constant.bool false + %2094 = torch.aten.arange %int8_1885, %none_1886, %none_1887, %cpu_1888, %false_1889 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> %int1_1890 = torch.constant.int 1 - %1769 = torch.aten.size.int %1759, %int1_1890 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_1891 = torch.constant.int 4 + %int1_1891 = torch.constant.int 1 %int8_1892 = torch.constant.int 8 - %int4_1893 = torch.constant.int 4 - %int128_1894 = torch.constant.int 128 - %1770 = torch.prim.ListConstruct %int4_1891, %1769, %int8_1892, %int4_1893, %int128_1894 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_1895 = torch.constant.bool false - %1771 = torch.aten.expand %1768, %1770, %false_1895 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1771, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_1896 = torch.constant.int 0 - %1772 = torch.aten.clone %1771, %int0_1896 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1772, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_1897 = torch.constant.int 4 - %int32_1898 = torch.constant.int 32 - %int128_1899 = torch.constant.int 128 - %1773 = torch.prim.ListConstruct %int4_1897, %1769, %int32_1898, %int128_1899 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1774 = torch.aten._unsafe_view %1772, %1773 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1774, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %2095 = torch.prim.ListConstruct %int1_1890, %int1_1891, %int8_1892 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2096 = torch.aten.view %2094, %2095 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_1893 = torch.constant.none + %2097 = torch.aten.clone %104, %none_1893 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2098 = torch.aten.detach %2097 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2099 = torch.aten.detach %2098 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2100 = torch.aten.detach %2099 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_1894 = torch.constant.int 1 + %int1_1895 = torch.constant.int 1 + %int1_1896 = torch.constant.int 1 + %2101 = torch.prim.ListConstruct %int1_1894, %int1_1895, %int1_1896 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2102 = torch.aten.view %2100, %2101 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_1897 = torch.constant.int 32 + %2103 = torch.aten.mul.Scalar %2090, %int32_1897 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int7 = torch.constant.int 7 + %int1_1898 = torch.constant.int 1 + %2104 = torch.aten.add.Scalar %2103, %int7, %int1_1898 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_1899 = torch.constant.int 2 + %2105 = torch.aten.mul.Scalar %2104, %int2_1899 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_1900 = torch.constant.int 1 - %int2_1901 = torch.constant.int 2 - %1775 = torch.aten.transpose.int %1655, %int1_1900, %int2_1901 : 
!torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %2106 = torch.aten.add.Tensor %2105, %2102, %int1_1900 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_1901 = torch.constant.int 8 + %2107 = torch.aten.mul.Scalar %2106, %int8_1901 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_1902 = torch.constant.int 1 - %int2_1903 = torch.constant.int 2 - %1776 = torch.aten.transpose.int %1767, %int1_1902, %int2_1903 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1776, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %2108 = torch.aten.add.Tensor %2107, %2096, %int1_1902 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_1903 = torch.constant.int 32 + %2109 = torch.aten.mul.Scalar %2108, %int32_1903 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_1904 = torch.constant.int 1 - %int2_1905 = torch.constant.int 2 - %1777 = torch.aten.transpose.int %1774, %int1_1904, %int2_1905 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1777, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_1906 = torch.constant.float 0.000000e+00 - %false_1907 = torch.constant.bool false - %none_1908 = torch.constant.none - %1778:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1775, %1776, %1777, %float0.000000e00_1906, %false_1907, %368, %none_1908) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_1909 = torch.constant.int 1 - %int2_1910 = torch.constant.int 2 - %1779 = torch.aten.transpose.int %1778#0, %int1_1909, %int2_1910 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_1911 = torch.constant.int 4 - %int1_1912 = torch.constant.int 1 - %int4096_1913 = torch.constant.int 4096 - %1780 = torch.prim.ListConstruct %int4_1911, %int1_1912, %int4096_1913 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1781 = torch.aten.view %1779, %1780 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_1914 = torch.constant.int -2 - %int-1_1915 = torch.constant.int -1 - %1782 = torch.aten.transpose.int %73, %int-2_1914, %int-1_1915 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_1916 = torch.constant.int 4 - %int4096_1917 = torch.constant.int 4096 - %1783 = torch.prim.ListConstruct %int4_1916, %int4096_1917 : (!torch.int, !torch.int) -> !torch.list - %1784 = torch.aten.view %1781, %1783 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1785 = torch.aten.mm %1784, %1782 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_1918 = torch.constant.int 4 - %int1_1919 = torch.constant.int 1 - %int4096_1920 = torch.constant.int 4096 - %1786 = torch.prim.ListConstruct %int4_1918, %int1_1919, %int4096_1920 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1787 = torch.aten.view %1785, %1786 : 
!torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_1921 = torch.constant.int 1 - %1788 = torch.aten.add.Tensor %1615, %1787, %int1_1921 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_1922 = torch.constant.int 6 - %1789 = torch.prims.convert_element_type %1788, %int6_1922 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_1923 = torch.constant.int 2 - %1790 = torch.aten.pow.Tensor_Scalar %1789, %int2_1923 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_1924 = torch.constant.int -1 - %1791 = torch.prim.ListConstruct %int-1_1924 : (!torch.int) -> !torch.list - %true_1925 = torch.constant.bool true - %none_1926 = torch.constant.none - %1792 = torch.aten.mean.dim %1790, %1791, %true_1925, %none_1926 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_1927 = torch.constant.float 9.9999997473787516E-6 + %2110 = torch.aten.add.Tensor %2109, %2093, %int1_1904 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_1905 = torch.constant.int 5 + %2111 = torch.prims.convert_element_type %2085, %int5_1905 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_1906 = torch.constant.int 32 + %int2_1907 = torch.constant.int 2 + %int8_1908 = torch.constant.int 8 + %int32_1909 = torch.constant.int 32 + %int128_1910 = torch.constant.int 128 + %2112 = torch.prim.ListConstruct %456, %int32_1906, %int2_1907, %int8_1908, %int32_1909, %int128_1910 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2113 = torch.aten.view %1933, %2112 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2113, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_1911 = torch.constant.int 128 + %2114 = torch.prim.ListConstruct %596, %int128_1911 : (!torch.int, !torch.int) -> !torch.list + %2115 = torch.aten.view %2113, %2114 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2115, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %2116 = torch.prim.ListConstruct %2110 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_1912 = torch.constant.bool false + %2117 = torch.aten.index_put %2115, %2116, %2111, %false_1912 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2117, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_1913 = torch.constant.int 32 + %int2_1914 = torch.constant.int 2 + %int8_1915 = torch.constant.int 8 + %int32_1916 = torch.constant.int 32 + %int128_1917 = torch.constant.int 128 + %2118 = torch.prim.ListConstruct %456, %int32_1913, %int2_1914, %int8_1915, %int32_1916, %int128_1917 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2119 = torch.aten.view %2117, %2118 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2119, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_1918 = torch.constant.int 2097152 + %2120 = 
torch.prim.ListConstruct %456, %int2097152_1918 : (!torch.int, !torch.int) -> !torch.list + %2121 = torch.aten.view %2119, %2120 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2121, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_1919 = torch.constant.int 32 + %int2_1920 = torch.constant.int 2 + %int8_1921 = torch.constant.int 8 + %int32_1922 = torch.constant.int 32 + %int128_1923 = torch.constant.int 128 + %2122 = torch.prim.ListConstruct %456, %int32_1919, %int2_1920, %int8_1921, %int32_1922, %int128_1923 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2123 = torch.aten.view %2121, %2122 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2123, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_1924 = torch.constant.int 128 + %2124 = torch.prim.ListConstruct %596, %int128_1924 : (!torch.int, !torch.int) -> !torch.list + %2125 = torch.aten.view %2123, %2124 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2125, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_1925 = torch.constant.none + %2126 = torch.aten.clone %105, %none_1925 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2127 = torch.aten.detach %2126 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2128 = torch.aten.detach %2127 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2129 = torch.aten.detach %2128 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_1926 = torch.constant.int 1 + %int1_1927 = torch.constant.int 1 %int1_1928 = torch.constant.int 1 - %1793 = torch.aten.add.Scalar %1792, %float9.999990e-06_1927, %int1_1928 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %1794 = torch.aten.rsqrt %1793 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %1795 = torch.aten.mul.Tensor %1789, %1794 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_1929 = torch.constant.int 5 - %1796 = torch.prims.convert_element_type %1795, %int5_1929 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %1797 = torch.aten.mul.Tensor %74, %1796 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_1930 = torch.constant.int 5 - %1798 = torch.prims.convert_element_type %1797, %int5_1930 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_1931 = torch.constant.int -2 - %int-1_1932 = torch.constant.int -1 - %1799 = torch.aten.transpose.int %75, %int-2_1931, %int-1_1932 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_1933 = torch.constant.int 4 - %int4096_1934 = torch.constant.int 4096 - %1800 = torch.prim.ListConstruct %int4_1933, %int4096_1934 : (!torch.int, !torch.int) -> !torch.list - %1801 = torch.aten.view %1798, %1800 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1802 = torch.aten.mm %1801, %1799 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_1935 = torch.constant.int 4 - %int1_1936 = torch.constant.int 1 - %int14336_1937 = torch.constant.int 14336 - %1803 = 
torch.prim.ListConstruct %int4_1935, %int1_1936, %int14336_1937 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1804 = torch.aten.view %1802, %1803 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %1805 = torch.aten.silu %1804 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_1938 = torch.constant.int -2 - %int-1_1939 = torch.constant.int -1 - %1806 = torch.aten.transpose.int %76, %int-2_1938, %int-1_1939 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_1940 = torch.constant.int 4 - %int4096_1941 = torch.constant.int 4096 - %1807 = torch.prim.ListConstruct %int4_1940, %int4096_1941 : (!torch.int, !torch.int) -> !torch.list - %1808 = torch.aten.view %1798, %1807 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1809 = torch.aten.mm %1808, %1806 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_1942 = torch.constant.int 4 - %int1_1943 = torch.constant.int 1 - %int14336_1944 = torch.constant.int 14336 - %1810 = torch.prim.ListConstruct %int4_1942, %int1_1943, %int14336_1944 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1811 = torch.aten.view %1809, %1810 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %1812 = torch.aten.mul.Tensor %1805, %1811 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_1945 = torch.constant.int -2 - %int-1_1946 = torch.constant.int -1 - %1813 = torch.aten.transpose.int %77, %int-2_1945, %int-1_1946 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_1947 = torch.constant.int 4 - %int14336_1948 = torch.constant.int 14336 - %1814 = torch.prim.ListConstruct %int4_1947, %int14336_1948 : (!torch.int, !torch.int) -> !torch.list - %1815 = torch.aten.view %1812, %1814 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %1816 = torch.aten.mm %1815, %1813 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_1949 = torch.constant.int 4 - %int1_1950 = torch.constant.int 1 - %int4096_1951 = torch.constant.int 4096 - %1817 = torch.prim.ListConstruct %int4_1949, %int1_1950, %int4096_1951 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1818 = torch.aten.view %1816, %1817 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_1952 = torch.constant.int 1 - %1819 = torch.aten.add.Tensor %1788, %1818, %int1_1952 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_1953 = torch.constant.int 6 - %1820 = torch.prims.convert_element_type %1819, %int6_1953 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_1954 = torch.constant.int 2 - %1821 = torch.aten.pow.Tensor_Scalar %1820, %int2_1954 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_1955 = torch.constant.int -1 - %1822 = torch.prim.ListConstruct %int-1_1955 : (!torch.int) -> !torch.list - %true_1956 = torch.constant.bool true - %none_1957 = torch.constant.none - %1823 = torch.aten.mean.dim %1821, %1822, %true_1956, %none_1957 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_1958 = torch.constant.float 9.9999997473787516E-6 - %int1_1959 = 
torch.constant.int 1 - %1824 = torch.aten.add.Scalar %1823, %float9.999990e-06_1958, %int1_1959 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %1825 = torch.aten.rsqrt %1824 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %1826 = torch.aten.mul.Tensor %1820, %1825 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_1960 = torch.constant.int 5 - %1827 = torch.prims.convert_element_type %1826, %int5_1960 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %1828 = torch.aten.mul.Tensor %78, %1827 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_1961 = torch.constant.int 5 - %1829 = torch.prims.convert_element_type %1828, %int5_1961 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_1962 = torch.constant.int -2 - %int-1_1963 = torch.constant.int -1 - %1830 = torch.aten.transpose.int %79, %int-2_1962, %int-1_1963 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_1964 = torch.constant.int 4 - %int4096_1965 = torch.constant.int 4096 - %1831 = torch.prim.ListConstruct %int4_1964, %int4096_1965 : (!torch.int, !torch.int) -> !torch.list - %1832 = torch.aten.view %1829, %1831 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1833 = torch.aten.mm %1832, %1830 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_1966 = torch.constant.int 4 - %int1_1967 = torch.constant.int 1 - %int4096_1968 = torch.constant.int 4096 - %1834 = torch.prim.ListConstruct %int4_1966, %int1_1967, %int4096_1968 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1835 = torch.aten.view %1833, %1834 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_1969 = torch.constant.int -2 - %int-1_1970 = torch.constant.int -1 - %1836 = torch.aten.transpose.int %80, %int-2_1969, %int-1_1970 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %2130 = torch.prim.ListConstruct %int1_1926, %int1_1927, %int1_1928 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2131 = torch.aten.view %2129, %2130 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_1929 = torch.constant.int 32 + %2132 = torch.aten.mul.Scalar %2090, %int32_1929 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int7_1930 = torch.constant.int 7 + %int1_1931 = torch.constant.int 1 + %2133 = torch.aten.add.Scalar %2132, %int7_1930, %int1_1931 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_1932 = torch.constant.int 2 + %2134 = torch.aten.mul.Scalar %2133, %int2_1932 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_1933 = torch.constant.int 1 + %2135 = torch.aten.add.Tensor %2134, %2131, %int1_1933 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_1934 = torch.constant.int 8 + %2136 = torch.aten.mul.Scalar %2135, %int8_1934 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_1935 = torch.constant.int 1 + %2137 = torch.aten.add.Tensor %2136, %2096, %int1_1935 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_1936 = torch.constant.int 32 + %2138 = 
torch.aten.mul.Scalar %2137, %int32_1936 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_1937 = torch.constant.int 1 + %2139 = torch.aten.add.Tensor %2138, %2093, %int1_1937 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_1938 = torch.constant.int 5 + %2140 = torch.prims.convert_element_type %2065, %int5_1938 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %2141 = torch.prim.ListConstruct %2139 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_1939 = torch.constant.bool false + %2142 = torch.aten.index_put %2125, %2141, %2140, %false_1939 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2142, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_1940 = torch.constant.int 32 + %int2_1941 = torch.constant.int 2 + %int8_1942 = torch.constant.int 8 + %int32_1943 = torch.constant.int 32 + %int128_1944 = torch.constant.int 128 + %2143 = torch.prim.ListConstruct %456, %int32_1940, %int2_1941, %int8_1942, %int32_1943, %int128_1944 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2144 = torch.aten.view %2142, %2143 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2144, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_1945 = torch.constant.int 2097152 + %2145 = torch.prim.ListConstruct %456, %int2097152_1945 : (!torch.int, !torch.int) -> !torch.list + %2146 = torch.aten.view %2144, %2145 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2146, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_1946 = torch.constant.none + %2147 = torch.aten.clone %106, %none_1946 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2148 = torch.aten.detach %2147 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2149 = torch.aten.detach %2148 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2150 = torch.aten.detach %2149 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_1947 = torch.constant.none + %2151 = torch.aten.clone %107, %none_1947 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2152 = torch.aten.detach %2151 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2153 = torch.aten.detach %2152 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2154 = torch.aten.detach %2153 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_1948 = torch.constant.none + %2155 = torch.aten.clone %108, %none_1948 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2156 = torch.aten.detach %2155 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2157 = torch.aten.detach %2156 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2158 = torch.aten.detach %2157 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_1949 = torch.constant.int 32 + %int2_1950 = torch.constant.int 2 + %int8_1951 = torch.constant.int 8 + %int32_1952 = torch.constant.int 32 + %int128_1953 = torch.constant.int 128 + %2159 = torch.prim.ListConstruct %456, %int32_1949, %int2_1950, %int8_1951, %int32_1952, %int128_1953 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2160 = 
torch.aten.view %2146, %2159 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2160, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %2161 = torch_c.to_builtin_tensor %2160 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %2162 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_1954 = tensor.cast %2162 : tensor<4x?xi64> to tensor + %2163 = torch_c.to_builtin_tensor %2150 : !torch.vtensor<[],si64> -> tensor + %2164 = torch_c.to_builtin_tensor %2154 : !torch.vtensor<[],si64> -> tensor + %2165 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%2161, %cast_1954, %2163, %2164) : (tensor, tensor, tensor, tensor) -> tensor + %cast_1955 = tensor.cast %2165 : tensor to tensor<4x?x8x32x128xf16> + %2166 = torch_c.from_builtin_tensor %cast_1955 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %2166, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %2167 = torch_c.to_builtin_tensor %2160 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %2168 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_1956 = tensor.cast %2168 : tensor<4x?xi64> to tensor + %2169 = torch_c.to_builtin_tensor %2150 : !torch.vtensor<[],si64> -> tensor + %2170 = torch_c.to_builtin_tensor %2158 : !torch.vtensor<[],si64> -> tensor + %2171 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%2167, %cast_1956, %2169, %2170) : (tensor, tensor, tensor, tensor) -> tensor + %cast_1957 = tensor.cast %2171 : tensor to tensor<4x?x8x32x128xf16> + %2172 = torch_c.from_builtin_tensor %cast_1957 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %2172, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_1958 = torch.constant.int 2 + %int3_1959 = torch.constant.int 3 + %2173 = torch.aten.transpose.int %2166, %int2_1958, %int3_1959 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2173, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_1960 = torch.constant.int 0 + %2174 = torch.aten.clone %2173, %int0_1960 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2174, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_1961 = torch.constant.int 4 + %int8_1962 = torch.constant.int 8 + %int128_1963 = torch.constant.int 128 + %2175 = torch.prim.ListConstruct %int4_1961, %457, %int8_1962, %int128_1963 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2176 = torch.aten._unsafe_view %2174, %2175 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2176, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_1964 = torch.constant.int 2 + %int3_1965 = torch.constant.int 3 + %2177 = torch.aten.transpose.int %2172, 
%int2_1964, %int3_1965 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2177, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_1966 = torch.constant.int 0 + %2178 = torch.aten.clone %2177, %int0_1966 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2178, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_1967 = torch.constant.int 4 + %int8_1968 = torch.constant.int 8 + %int128_1969 = torch.constant.int 128 + %2179 = torch.prim.ListConstruct %int4_1967, %457, %int8_1968, %int128_1969 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2180 = torch.aten._unsafe_view %2178, %2179 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2180, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_1970 = torch.constant.int -2 + %2181 = torch.aten.unsqueeze %2176, %int-2_1970 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2181, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> %int4_1971 = torch.constant.int 4 - %int4096_1972 = torch.constant.int 4096 - %1837 = torch.prim.ListConstruct %int4_1971, %int4096_1972 : (!torch.int, !torch.int) -> !torch.list - %1838 = torch.aten.view %1829, %1837 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1839 = torch.aten.mm %1838, %1836 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int8_1972 = torch.constant.int 8 %int4_1973 = torch.constant.int 4 - %int1_1974 = torch.constant.int 1 - %int1024_1975 = torch.constant.int 1024 - %1840 = torch.prim.ListConstruct %int4_1973, %int1_1974, %int1024_1975 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1841 = torch.aten.view %1839, %1840 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_1976 = torch.constant.int -2 - %int-1_1977 = torch.constant.int -1 - %1842 = torch.aten.transpose.int %81, %int-2_1976, %int-1_1977 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_1978 = torch.constant.int 4 - %int4096_1979 = torch.constant.int 4096 - %1843 = torch.prim.ListConstruct %int4_1978, %int4096_1979 : (!torch.int, !torch.int) -> !torch.list - %1844 = torch.aten.view %1829, %1843 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %1845 = torch.aten.mm %1844, %1842 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_1980 = torch.constant.int 4 - %int1_1981 = torch.constant.int 1 - %int1024_1982 = torch.constant.int 1024 - %1846 = torch.prim.ListConstruct %int4_1980, %int1_1981, %int1024_1982 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1847 = torch.aten.view %1845, %1846 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int128_1974 = torch.constant.int 128 + %2182 = torch.prim.ListConstruct %int4_1971, %457, %int8_1972, %int4_1973, %int128_1974 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_1975 = torch.constant.bool false + %2183 = torch.aten.expand %2181, %2182, %false_1975 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> 
!torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2183, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_1976 = torch.constant.int 0 + %2184 = torch.aten.clone %2183, %int0_1976 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2184, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_1977 = torch.constant.int 4 + %int32_1978 = torch.constant.int 32 + %int128_1979 = torch.constant.int 128 + %2185 = torch.prim.ListConstruct %int4_1977, %457, %int32_1978, %int128_1979 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2186 = torch.aten._unsafe_view %2184, %2185 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2186, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_1980 = torch.constant.int -2 + %2187 = torch.aten.unsqueeze %2180, %int-2_1980 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2187, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_1981 = torch.constant.int 4 + %int8_1982 = torch.constant.int 8 %int4_1983 = torch.constant.int 4 - %int1_1984 = torch.constant.int 1 - %int32_1985 = torch.constant.int 32 - %int128_1986 = torch.constant.int 128 - %1848 = torch.prim.ListConstruct %int4_1983, %int1_1984, %int32_1985, %int128_1986 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1849 = torch.aten.view %1835, %1848 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int128_1984 = torch.constant.int 128 + %2188 = torch.prim.ListConstruct %int4_1981, %457, %int8_1982, %int4_1983, %int128_1984 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_1985 = torch.constant.bool false + %2189 = torch.aten.expand %2187, %2188, %false_1985 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2189, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_1986 = torch.constant.int 0 + %2190 = torch.aten.clone %2189, %int0_1986 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2190, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_1987 = torch.constant.int 4 - %int1_1988 = torch.constant.int 1 - %int8_1989 = torch.constant.int 8 - %int128_1990 = torch.constant.int 128 - %1850 = torch.prim.ListConstruct %int4_1987, %int1_1988, %int8_1989, %int128_1990 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1851 = torch.aten.view %1841, %1850 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_1991 = torch.constant.int 4 + %int32_1988 = torch.constant.int 32 + %int128_1989 = torch.constant.int 128 + %2191 = torch.prim.ListConstruct %int4_1987, %457, %int32_1988, %int128_1989 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2192 = torch.aten._unsafe_view %2190, %2191 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2192, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_1990 = torch.constant.int 1 + %int2_1991 = 
torch.constant.int 2 + %2193 = torch.aten.transpose.int %2075, %int1_1990, %int2_1991 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_1992 = torch.constant.int 1 - %int8_1993 = torch.constant.int 8 - %int128_1994 = torch.constant.int 128 - %1852 = torch.prim.ListConstruct %int4_1991, %int1_1992, %int8_1993, %int128_1994 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1853 = torch.aten.view %1847, %1852 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_1995 = torch.constant.int 6 - %1854 = torch.prims.convert_element_type %1849, %int6_1995 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %1855 = torch_c.to_builtin_tensor %1854 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %1856 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %1857 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%1855, %1856) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %1858 = torch_c.from_builtin_tensor %1857 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_1996 = torch.constant.int 5 - %1859 = torch.prims.convert_element_type %1858, %int5_1996 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_1997 = torch.constant.int 6 - %1860 = torch.prims.convert_element_type %1851, %int6_1997 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %1861 = torch_c.to_builtin_tensor %1860 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %1862 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %1863 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%1861, %1862) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %1864 = torch_c.from_builtin_tensor %1863 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_1998 = torch.constant.int 5 - %1865 = torch.prims.convert_element_type %1864, %int5_1998 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_1999 = torch.constant.int 32 - %1866 = torch.aten.floor_divide.Scalar %arg2, %int32_1999 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2000 = torch.constant.int 1 - %1867 = torch.aten.unsqueeze %1866, %int1_2000 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2001 = torch.constant.int 1 - %false_2002 = torch.constant.bool false - %1868 = torch.aten.gather %arg3, %int1_2001, %1867, %false_2002 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_2003 = torch.constant.int 32 - %1869 = torch.aten.remainder.Scalar %arg2, %int32_2003 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2004 = torch.constant.int 1 - %1870 = torch.aten.unsqueeze %1869, %int1_2004 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_2005 = torch.constant.none - %1871 = torch.aten.clone %82, %none_2005 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_2006 = torch.constant.int 0 - %1872 = torch.aten.unsqueeze %1871, %int0_2006 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + %int2_1993 = torch.constant.int 2 + %2194 = torch.aten.transpose.int %2186, %int1_1992, %int2_1993 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> 
!torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2194, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_1994 = torch.constant.int 1 + %int2_1995 = torch.constant.int 2 + %2195 = torch.aten.transpose.int %2192, %int1_1994, %int2_1995 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2195, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_1996 = torch.constant.float 0.000000e+00 + %false_1997 = torch.constant.bool false + %none_1998 = torch.constant.none + %2196:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2193, %2194, %2195, %float0.000000e00_1996, %false_1997, %470, %none_1998) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_1999 = torch.constant.int 1 + %int2_2000 = torch.constant.int 2 + %2197 = torch.aten.transpose.int %2196#0, %int1_1999, %int2_2000 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_2001 = torch.constant.int 4 + %int1_2002 = torch.constant.int 1 + %int4096_2003 = torch.constant.int 4096 + %2198 = torch.prim.ListConstruct %int4_2001, %int1_2002, %int4096_2003 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2199 = torch.aten.view %2197, %2198 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_2004 = torch.constant.int -2 + %int-1_2005 = torch.constant.int -1 + %2200 = torch.aten.transpose.int %109, %int-2_2004, %int-1_2005 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2006 = torch.constant.int 5 + %2201 = torch.prims.convert_element_type %2200, %int5_2006 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> %int4_2007 = torch.constant.int 4 - %int1_2008 = torch.constant.int 1 - %1873 = torch.prim.ListConstruct %int4_2007, %int1_2008 : (!torch.int, !torch.int) -> !torch.list - %int1_2009 = torch.constant.int 1 + %int4096_2008 = torch.constant.int 4096 + %2202 = torch.prim.ListConstruct %int4_2007, %int4096_2008 : (!torch.int, !torch.int) -> !torch.list + %2203 = torch.aten.view %2199, %2202 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2204 = torch.aten.mm %2203, %2201 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_2009 = torch.constant.int 4 %int1_2010 = torch.constant.int 1 - %1874 = torch.prim.ListConstruct %int1_2009, %int1_2010 : (!torch.int, !torch.int) -> !torch.list - %int4_2011 = torch.constant.int 4 - %int0_2012 = torch.constant.int 0 - %cpu_2013 = torch.constant.device "cpu" - %false_2014 = torch.constant.bool false - %1875 = torch.aten.empty_strided %1873, %1874, %int4_2011, %int0_2012, %cpu_2013, %false_2014 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int7 = torch.constant.int 7 - %1876 = torch.aten.fill.Scalar %1875, %int7 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_2015 = torch.constant.int 4 - %int1_2016 = torch.constant.int 1 - %1877 = torch.prim.ListConstruct %int4_2015, %int1_2016 : (!torch.int, !torch.int) -> !torch.list - %1878 = torch.aten.repeat %1872, %1877 : 
!torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_2017 = torch.constant.int 32 - %1879 = torch.aten.mul.Scalar %1868, %int32_2017 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2018 = torch.constant.int 1 - %1880 = torch.aten.add.Tensor %1879, %1876, %int1_2018 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_2019 = torch.constant.int 2 - %1881 = torch.aten.mul.Scalar %1880, %int2_2019 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2020 = torch.constant.int 1 - %1882 = torch.aten.add.Tensor %1881, %1878, %int1_2020 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_2021 = torch.constant.int 32 - %1883 = torch.aten.mul.Scalar %1882, %int32_2021 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2022 = torch.constant.int 1 - %1884 = torch.aten.add.Tensor %1883, %1870, %int1_2022 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_2023 = torch.constant.int 32 - %int2_2024 = torch.constant.int 2 - %int32_2025 = torch.constant.int 32 - %int8_2026 = torch.constant.int 8 - %int128_2027 = torch.constant.int 128 - %1885 = torch.prim.ListConstruct %437, %int32_2023, %int2_2024, %int32_2025, %int8_2026, %int128_2027 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1886 = torch.aten.view %1722, %1885 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1886, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2028 = torch.constant.int 32 - %1887 = torch.aten.mul.int %437, %int32_2028 : !torch.int, !torch.int -> !torch.int - %int2_2029 = torch.constant.int 2 - %1888 = torch.aten.mul.int %1887, %int2_2029 : !torch.int, !torch.int -> !torch.int - %int32_2030 = torch.constant.int 32 - %1889 = torch.aten.mul.int %1888, %int32_2030 : !torch.int, !torch.int -> !torch.int - %int8_2031 = torch.constant.int 8 - %int128_2032 = torch.constant.int 128 - %1890 = torch.prim.ListConstruct %1889, %int8_2031, %int128_2032 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1891 = torch.aten.view %1886, %1890 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1891, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %1892 = torch.prim.ListConstruct %1884 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_2033 = torch.constant.bool false - %1893 = torch.aten.index_put %1891, %1892, %1865, %false_2033 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1893, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_2034 = torch.constant.int 32 - %int2_2035 = torch.constant.int 2 - %int32_2036 = torch.constant.int 32 - %int8_2037 = torch.constant.int 8 - %int128_2038 = torch.constant.int 128 - %1894 = torch.prim.ListConstruct %437, %int32_2034, %int2_2035, %int32_2036, %int8_2037, %int128_2038 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1895 = torch.aten.view %1893, %1894 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1895, [%357], 
affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2039 = torch.constant.int 2097152 - %1896 = torch.prim.ListConstruct %437, %int2097152_2039 : (!torch.int, !torch.int) -> !torch.list - %1897 = torch.aten.view %1895, %1896 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1897, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_2040 = torch.constant.int 32 - %int2_2041 = torch.constant.int 2 - %int32_2042 = torch.constant.int 32 - %int8_2043 = torch.constant.int 8 - %int128_2044 = torch.constant.int 128 - %1898 = torch.prim.ListConstruct %437, %int32_2040, %int2_2041, %int32_2042, %int8_2043, %int128_2044 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1899 = torch.aten.view %1897, %1898 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1899, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_2045 = torch.constant.int 8 - %int128_2046 = torch.constant.int 128 - %1900 = torch.prim.ListConstruct %1889, %int8_2045, %int128_2046 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1901 = torch.aten.view %1899, %1900 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1901, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_2047 = torch.constant.int 32 - %1902 = torch.aten.floor_divide.Scalar %arg2, %int32_2047 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2048 = torch.constant.int 1 - %1903 = torch.aten.unsqueeze %1902, %int1_2048 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2049 = torch.constant.int 1 - %false_2050 = torch.constant.bool false - %1904 = torch.aten.gather %arg3, %int1_2049, %1903, %false_2050 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_2051 = torch.constant.int 32 - %1905 = torch.aten.remainder.Scalar %arg2, %int32_2051 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2052 = torch.constant.int 1 - %1906 = torch.aten.unsqueeze %1905, %int1_2052 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_2053 = torch.constant.none - %1907 = torch.aten.clone %83, %none_2053 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_2054 = torch.constant.int 0 - %1908 = torch.aten.unsqueeze %1907, %int0_2054 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_2055 = torch.constant.int 4 - %int1_2056 = torch.constant.int 1 - %1909 = torch.prim.ListConstruct %int4_2055, %int1_2056 : (!torch.int, !torch.int) -> !torch.list - %int1_2057 = torch.constant.int 1 - %int1_2058 = torch.constant.int 1 - %1910 = torch.prim.ListConstruct %int1_2057, %int1_2058 : (!torch.int, !torch.int) -> !torch.list + %int4096_2011 = torch.constant.int 4096 + %2205 = torch.prim.ListConstruct %int4_2009, %int1_2010, %int4096_2011 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2206 = torch.aten.view %2204, %2205 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_2012 = torch.constant.int 1 + %2207 = torch.aten.add.Tensor %2028, %2206, %int1_2012 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> 
!torch.vtensor<[4,1,4096],f16> + %int6_2013 = torch.constant.int 6 + %2208 = torch.prims.convert_element_type %2207, %int6_2013 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_2014 = torch.constant.int 2 + %2209 = torch.aten.pow.Tensor_Scalar %2208, %int2_2014 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_2015 = torch.constant.int -1 + %2210 = torch.prim.ListConstruct %int-1_2015 : (!torch.int) -> !torch.list + %true_2016 = torch.constant.bool true + %none_2017 = torch.constant.none + %2211 = torch.aten.mean.dim %2209, %2210, %true_2016, %none_2017 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_2018 = torch.constant.float 9.9999997473787516E-6 + %int1_2019 = torch.constant.int 1 + %2212 = torch.aten.add.Scalar %2211, %float9.999990e-06_2018, %int1_2019 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %2213 = torch.aten.rsqrt %2212 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %2214 = torch.aten.mul.Tensor %2208, %2213 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_2020 = torch.constant.int 5 + %2215 = torch.prims.convert_element_type %2214, %int5_2020 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %2216 = torch.aten.mul.Tensor %110, %2215 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_2021 = torch.constant.int 5 + %2217 = torch.prims.convert_element_type %2216, %int5_2021 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_2022 = torch.constant.int -2 + %int-1_2023 = torch.constant.int -1 + %2218 = torch.aten.transpose.int %111, %int-2_2022, %int-1_2023 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2024 = torch.constant.int 5 + %2219 = torch.prims.convert_element_type %2218, %int5_2024 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_2025 = torch.constant.int 4 + %int4096_2026 = torch.constant.int 4096 + %2220 = torch.prim.ListConstruct %int4_2025, %int4096_2026 : (!torch.int, !torch.int) -> !torch.list + %2221 = torch.aten.view %2217, %2220 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2222 = torch.aten.mm %2221, %2219 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_2027 = torch.constant.int 4 + %int1_2028 = torch.constant.int 1 + %int14336_2029 = torch.constant.int 14336 + %2223 = torch.prim.ListConstruct %int4_2027, %int1_2028, %int14336_2029 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2224 = torch.aten.view %2222, %2223 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %2225 = torch.aten.silu %2224 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_2030 = torch.constant.int -2 + %int-1_2031 = torch.constant.int -1 + %2226 = torch.aten.transpose.int %112, %int-2_2030, %int-1_2031 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2032 = torch.constant.int 5 + %2227 = torch.prims.convert_element_type %2226, %int5_2032 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_2033 = torch.constant.int 4 + %int4096_2034 = torch.constant.int 4096 + %2228 = 
torch.prim.ListConstruct %int4_2033, %int4096_2034 : (!torch.int, !torch.int) -> !torch.list + %2229 = torch.aten.view %2217, %2228 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2230 = torch.aten.mm %2229, %2227 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_2035 = torch.constant.int 4 + %int1_2036 = torch.constant.int 1 + %int14336_2037 = torch.constant.int 14336 + %2231 = torch.prim.ListConstruct %int4_2035, %int1_2036, %int14336_2037 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2232 = torch.aten.view %2230, %2231 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %2233 = torch.aten.mul.Tensor %2225, %2232 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_2038 = torch.constant.int -2 + %int-1_2039 = torch.constant.int -1 + %2234 = torch.aten.transpose.int %113, %int-2_2038, %int-1_2039 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_2040 = torch.constant.int 5 + %2235 = torch.prims.convert_element_type %2234, %int5_2040 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_2041 = torch.constant.int 4 + %int14336_2042 = torch.constant.int 14336 + %2236 = torch.prim.ListConstruct %int4_2041, %int14336_2042 : (!torch.int, !torch.int) -> !torch.list + %2237 = torch.aten.view %2233, %2236 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %2238 = torch.aten.mm %2237, %2235 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_2043 = torch.constant.int 4 + %int1_2044 = torch.constant.int 1 + %int4096_2045 = torch.constant.int 4096 + %2239 = torch.prim.ListConstruct %int4_2043, %int1_2044, %int4096_2045 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2240 = torch.aten.view %2238, %2239 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_2046 = torch.constant.int 1 + %2241 = torch.aten.add.Tensor %2207, %2240, %int1_2046 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_2047 = torch.constant.int 6 + %2242 = torch.prims.convert_element_type %2241, %int6_2047 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_2048 = torch.constant.int 2 + %2243 = torch.aten.pow.Tensor_Scalar %2242, %int2_2048 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_2049 = torch.constant.int -1 + %2244 = torch.prim.ListConstruct %int-1_2049 : (!torch.int) -> !torch.list + %true_2050 = torch.constant.bool true + %none_2051 = torch.constant.none + %2245 = torch.aten.mean.dim %2243, %2244, %true_2050, %none_2051 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_2052 = torch.constant.float 9.9999997473787516E-6 + %int1_2053 = torch.constant.int 1 + %2246 = torch.aten.add.Scalar %2245, %float9.999990e-06_2052, %int1_2053 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %2247 = torch.aten.rsqrt %2246 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %2248 = torch.aten.mul.Tensor %2242, %2247 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_2054 = torch.constant.int 5 + %2249 = 
torch.prims.convert_element_type %2248, %int5_2054 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %2250 = torch.aten.mul.Tensor %114, %2249 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_2055 = torch.constant.int 5 + %2251 = torch.prims.convert_element_type %2250, %int5_2055 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_2056 = torch.constant.int -2 + %int-1_2057 = torch.constant.int -1 + %2252 = torch.aten.transpose.int %115, %int-2_2056, %int-1_2057 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2058 = torch.constant.int 5 + %2253 = torch.prims.convert_element_type %2252, %int5_2058 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> %int4_2059 = torch.constant.int 4 - %int0_2060 = torch.constant.int 0 - %cpu_2061 = torch.constant.device "cpu" - %false_2062 = torch.constant.bool false - %1911 = torch.aten.empty_strided %1909, %1910, %int4_2059, %int0_2060, %cpu_2061, %false_2062 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int7_2063 = torch.constant.int 7 - %1912 = torch.aten.fill.Scalar %1911, %int7_2063 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_2064 = torch.constant.int 4 - %int1_2065 = torch.constant.int 1 - %1913 = torch.prim.ListConstruct %int4_2064, %int1_2065 : (!torch.int, !torch.int) -> !torch.list<int> - %1914 = torch.aten.repeat %1908, %1913 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64> - %int32_2066 = torch.constant.int 32 - %1915 = torch.aten.mul.Scalar %1904, %int32_2066 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2067 = torch.constant.int 1 - %1916 = torch.aten.add.Tensor %1915, %1912, %int1_2067 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_2068 = torch.constant.int 2 - %1917 = torch.aten.mul.Scalar %1916, %int2_2068 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2069 = torch.constant.int 1 - %1918 = torch.aten.add.Tensor %1917, %1914, %int1_2069 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_2070 = torch.constant.int 32 - %1919 = torch.aten.mul.Scalar %1918, %int32_2070 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2071 = torch.constant.int 1 - %1920 = torch.aten.add.Tensor %1919, %1906, %int1_2071 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %1921 = torch.prim.ListConstruct %1920 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>> - %false_2072 = torch.constant.bool false - %1922 = torch.aten.index_put %1901, %1921, %1853, %false_2072 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %1922, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_2073 = torch.constant.int 32 - %int2_2074 = torch.constant.int 2 - %int32_2075 = torch.constant.int 32 - %int8_2076 = torch.constant.int 8 - %int128_2077 = torch.constant.int 128 - %1923 = torch.prim.ListConstruct %437, %int32_2073, %int2_2074, %int32_2075, %int8_2076, %int128_2077 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %1924 = torch.aten.view
%1922, %1923 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1924, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2078 = torch.constant.int 2097152 - %1925 = torch.prim.ListConstruct %437, %int2097152_2078 : (!torch.int, !torch.int) -> !torch.list - %1926 = torch.aten.view %1924, %1925 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %1926, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_2079 = torch.constant.int 4 - %1927 = torch.prim.ListConstruct %int4_2079, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_2080 = torch.constant.int 1 - %1928 = torch.prim.ListConstruct %358, %int1_2080 : (!torch.int, !torch.int) -> !torch.list - %int4_2081 = torch.constant.int 4 - %int0_2082 = torch.constant.int 0 - %cpu_2083 = torch.constant.device "cpu" - %false_2084 = torch.constant.bool false - %1929 = torch.aten.empty_strided %1927, %1928, %int4_2081, %int0_2082, %cpu_2083, %false_2084 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1929, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int7_2085 = torch.constant.int 7 - %1930 = torch.aten.fill.Scalar %1929, %int7_2085 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1930, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_2086 = torch.constant.int 32 - %1931 = torch.aten.mul.Scalar %arg3, %int32_2086 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1931, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_2087 = torch.constant.int 1 - %1932 = torch.aten.add.Tensor %1931, %1930, %int1_2087 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %1932, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int4096_2060 = torch.constant.int 4096 + %2254 = torch.prim.ListConstruct %int4_2059, %int4096_2060 : (!torch.int, !torch.int) -> !torch.list + %2255 = torch.aten.view %2251, %2254 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2256 = torch.aten.mm %2255, %2253 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_2061 = torch.constant.int 4 + %int1_2062 = torch.constant.int 1 + %int4096_2063 = torch.constant.int 4096 + %2257 = torch.prim.ListConstruct %int4_2061, %int1_2062, %int4096_2063 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2258 = torch.aten.view %2256, %2257 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_2064 = torch.constant.int -2 + %int-1_2065 = torch.constant.int -1 + %2259 = torch.aten.transpose.int %116, %int-2_2064, %int-1_2065 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2066 = torch.constant.int 5 + %2260 = torch.prims.convert_element_type %2259, %int5_2066 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_2067 = torch.constant.int 4 + %int4096_2068 = torch.constant.int 4096 + %2261 = torch.prim.ListConstruct %int4_2067, %int4096_2068 : (!torch.int, !torch.int) -> !torch.list + %2262 = torch.aten.view %2251, %2261 : 
!torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2263 = torch.aten.mm %2262, %2260 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_2069 = torch.constant.int 4 + %int1_2070 = torch.constant.int 1 + %int1024_2071 = torch.constant.int 1024 + %2264 = torch.prim.ListConstruct %int4_2069, %int1_2070, %int1024_2071 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2265 = torch.aten.view %2263, %2264 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_2072 = torch.constant.int -2 + %int-1_2073 = torch.constant.int -1 + %2266 = torch.aten.transpose.int %117, %int-2_2072, %int-1_2073 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2074 = torch.constant.int 5 + %2267 = torch.prims.convert_element_type %2266, %int5_2074 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_2075 = torch.constant.int 4 + %int4096_2076 = torch.constant.int 4096 + %2268 = torch.prim.ListConstruct %int4_2075, %int4096_2076 : (!torch.int, !torch.int) -> !torch.list + %2269 = torch.aten.view %2251, %2268 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2270 = torch.aten.mm %2269, %2267 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_2077 = torch.constant.int 4 + %int1_2078 = torch.constant.int 1 + %int1024_2079 = torch.constant.int 1024 + %2271 = torch.prim.ListConstruct %int4_2077, %int1_2078, %int1024_2079 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2272 = torch.aten.view %2270, %2271 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_2080 = torch.constant.int 4 + %int1_2081 = torch.constant.int 1 + %int32_2082 = torch.constant.int 32 + %int128_2083 = torch.constant.int 128 + %2273 = torch.prim.ListConstruct %int4_2080, %int1_2081, %int32_2082, %int128_2083 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2274 = torch.aten.view %2258, %2273 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_2084 = torch.constant.int 4 + %int1_2085 = torch.constant.int 1 + %int8_2086 = torch.constant.int 8 + %int128_2087 = torch.constant.int 128 + %2275 = torch.prim.ListConstruct %int4_2084, %int1_2085, %int8_2086, %int128_2087 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2276 = torch.aten.view %2265, %2275 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> %int4_2088 = torch.constant.int 4 - %1933 = torch.aten.mul.int %int4_2088, %358 : !torch.int, !torch.int -> !torch.int - %1934 = torch.prim.ListConstruct %1933 : (!torch.int) -> !torch.list - %1935 = torch.aten.view %1932, %1934 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %1935, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_2089 = torch.constant.int 32 - %int2_2090 = torch.constant.int 2 - %int32_2091 = torch.constant.int 32 - %int8_2092 = torch.constant.int 8 - %int128_2093 = torch.constant.int 128 - %1936 = torch.prim.ListConstruct %437, %int32_2089, %int2_2090, %int32_2091, %int8_2092, %int128_2093 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1937 = torch.aten.view %1926, %1936 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %1937, 
[%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2094 = torch.constant.int 32 - %1938 = torch.aten.mul.int %437, %int32_2094 : !torch.int, !torch.int -> !torch.int - %int2_2095 = torch.constant.int 2 - %int32_2096 = torch.constant.int 32 - %int8_2097 = torch.constant.int 8 - %int128_2098 = torch.constant.int 128 - %1939 = torch.prim.ListConstruct %1938, %int2_2095, %int32_2096, %int8_2097, %int128_2098 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1940 = torch.aten.view %1937, %1939 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %1940, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_2099 = torch.constant.int 0 - %1941 = torch.aten.index_select %1940, %int0_2099, %1935 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %1941, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_2100 = torch.constant.int 4 - %int2_2101 = torch.constant.int 2 - %int32_2102 = torch.constant.int 32 - %int8_2103 = torch.constant.int 8 - %int128_2104 = torch.constant.int 128 - %1942 = torch.prim.ListConstruct %int4_2100, %358, %int2_2101, %int32_2102, %int8_2103, %int128_2104 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1943 = torch.aten.view %1941, %1942 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1943, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_2105 = torch.constant.int 0 - %int0_2106 = torch.constant.int 0 - %int9223372036854775807_2107 = torch.constant.int 9223372036854775807 - %int1_2108 = torch.constant.int 1 - %1944 = torch.aten.slice.Tensor %1943, %int0_2105, %int0_2106, %int9223372036854775807_2107, %int1_2108 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1944, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_2109 = torch.constant.int 1 - %int0_2110 = torch.constant.int 0 - %int9223372036854775807_2111 = torch.constant.int 9223372036854775807 - %int1_2112 = torch.constant.int 1 - %1945 = torch.aten.slice.Tensor %1944, %int1_2109, %int0_2110, %int9223372036854775807_2111, %int1_2112 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1945, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_2113 = torch.constant.int 2 - %int0_2114 = torch.constant.int 0 - %1946 = torch.aten.select.int %1945, %int2_2113, %int0_2114 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1946, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_2115 = torch.constant.int 32 - %1947 = torch.aten.mul.int %358, %int32_2115 : !torch.int, !torch.int -> !torch.int - %int2_2116 = torch.constant.int 2 - %int0_2117 = torch.constant.int 0 + %int1_2089 = torch.constant.int 1 + %int8_2090 = torch.constant.int 8 + %int128_2091 = torch.constant.int 128 + %2277 = torch.prim.ListConstruct 
%int4_2088, %int1_2089, %int8_2090, %int128_2091 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2278 = torch.aten.view %2272, %2277 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_2092 = torch.constant.int 1 + %int2_2093 = torch.constant.int 2 + %2279 = torch.aten.transpose.int %2274, %int1_2092, %int2_2093 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %2280 = torch.aten.mul.Tensor %2279, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_2094 = torch.constant.int 3 + %int0_2095 = torch.constant.int 0 + %int64_2096 = torch.constant.int 64 + %int1_2097 = torch.constant.int 1 + %2281 = torch.aten.slice.Tensor %2279, %int3_2094, %int0_2095, %int64_2096, %int1_2097 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_2098 = torch.constant.int 3 + %int64_2099 = torch.constant.int 64 + %int9223372036854775807_2100 = torch.constant.int 9223372036854775807 + %int1_2101 = torch.constant.int 1 + %2282 = torch.aten.slice.Tensor %2279, %int3_2098, %int64_2099, %int9223372036854775807_2100, %int1_2101 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %2283 = torch.aten.neg %2282 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %2284 = torch.prim.ListConstruct %2283, %2281 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_2102 = torch.constant.int -1 + %2285 = torch.aten.cat %2284, %int-1_2102 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %2286 = torch.aten.mul.Tensor %2285, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_2103 = torch.constant.int 1 + %2287 = torch.aten.add.Tensor %2280, %2286, %int1_2103 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_2104 = torch.constant.int 1 + %int2_2105 = torch.constant.int 2 + %2288 = torch.aten.transpose.int %2287, %int1_2104, %int2_2105 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_2106 = torch.constant.int 1 + %int2_2107 = torch.constant.int 2 + %2289 = torch.aten.transpose.int %2276, %int1_2106, %int2_2107 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %2290 = torch.aten.mul.Tensor %2289, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_2108 = torch.constant.int 3 + %int0_2109 = torch.constant.int 0 + %int64_2110 = torch.constant.int 64 + %int1_2111 = torch.constant.int 1 + %2291 = torch.aten.slice.Tensor %2289, %int3_2108, %int0_2109, %int64_2110, %int1_2111 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_2112 = torch.constant.int 3 + %int64_2113 = torch.constant.int 64 + %int9223372036854775807_2114 = torch.constant.int 9223372036854775807 + %int1_2115 = torch.constant.int 1 + %2292 = torch.aten.slice.Tensor %2289, %int3_2112, %int64_2113, %int9223372036854775807_2114, %int1_2115 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %2293 = torch.aten.neg %2292 : !torch.vtensor<[4,8,1,64],f16> -> 
!torch.vtensor<[4,8,1,64],f16> + %2294 = torch.prim.ListConstruct %2293, %2291 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_2116 = torch.constant.int -1 + %2295 = torch.aten.cat %2294, %int-1_2116 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %2296 = torch.aten.mul.Tensor %2295, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_2117 = torch.constant.int 1 + %2297 = torch.aten.add.Tensor %2290, %2296, %int1_2117 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> %int1_2118 = torch.constant.int 1 - %1948 = torch.aten.slice.Tensor %1946, %int2_2116, %int0_2117, %1947, %int1_2118 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1948, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_2119 = torch.constant.int 0 - %1949 = torch.aten.clone %1948, %int0_2119 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1949, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_2120 = torch.constant.int 1 - %1950 = torch.aten.size.int %1945, %int1_2120 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_2121 = torch.constant.int 32 - %1951 = torch.aten.mul.int %1950, %int32_2121 : !torch.int, !torch.int -> !torch.int - %int4_2122 = torch.constant.int 4 - %int8_2123 = torch.constant.int 8 - %int128_2124 = torch.constant.int 128 - %1952 = torch.prim.ListConstruct %int4_2122, %1951, %int8_2123, %int128_2124 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1953 = torch.aten._unsafe_view %1949, %1952 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1953, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_2125 = torch.constant.int 0 - %int0_2126 = torch.constant.int 0 - %int9223372036854775807_2127 = torch.constant.int 9223372036854775807 - %int1_2128 = torch.constant.int 1 - %1954 = torch.aten.slice.Tensor %1953, %int0_2125, %int0_2126, %int9223372036854775807_2127, %int1_2128 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1954, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_2129 = torch.constant.int 0 - %int0_2130 = torch.constant.int 0 - %int9223372036854775807_2131 = torch.constant.int 9223372036854775807 - %int1_2132 = torch.constant.int 1 - %1955 = torch.aten.slice.Tensor %1943, %int0_2129, %int0_2130, %int9223372036854775807_2131, %int1_2132 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1955, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_2133 = torch.constant.int 1 - %int0_2134 = torch.constant.int 0 - %int9223372036854775807_2135 = torch.constant.int 9223372036854775807 + %int2_2119 = torch.constant.int 2 + %2298 = torch.aten.transpose.int %2297, %int1_2118, %int2_2119 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_2120 = torch.constant.int 32 + %2299 = 
torch.aten.floor_divide.Scalar %arg2, %int32_2120 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_2121 = torch.constant.int 1 + %2300 = torch.aten.unsqueeze %2299, %int1_2121 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_2122 = torch.constant.int 1 + %false_2123 = torch.constant.bool false + %2301 = torch.aten.gather %arg3, %int1_2122, %2300, %false_2123 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_2124 = torch.constant.int 4 + %int1_2125 = torch.constant.int 1 + %int1_2126 = torch.constant.int 1 + %2302 = torch.prim.ListConstruct %int4_2124, %int1_2125, %int1_2126 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2303 = torch.aten.view %2301, %2302 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_2127 = torch.constant.int 32 + %2304 = torch.aten.remainder.Scalar %arg2, %int32_2127 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_2128 = torch.constant.int 4 + %int1_2129 = torch.constant.int 1 + %int1_2130 = torch.constant.int 1 + %2305 = torch.prim.ListConstruct %int4_2128, %int1_2129, %int1_2130 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2306 = torch.aten.view %2304, %2305 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_2131 = torch.constant.int 8 + %none_2132 = torch.constant.none + %none_2133 = torch.constant.none + %cpu_2134 = torch.constant.device "cpu" + %false_2135 = torch.constant.bool false + %2307 = torch.aten.arange %int8_2131, %none_2132, %none_2133, %cpu_2134, %false_2135 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> %int1_2136 = torch.constant.int 1 - %1956 = torch.aten.slice.Tensor %1955, %int1_2133, %int0_2134, %int9223372036854775807_2135, %int1_2136 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %1956, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_2137 = torch.constant.int 2 - %int1_2138 = torch.constant.int 1 - %1957 = torch.aten.select.int %1956, %int2_2137, %int1_2138 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1957, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_2139 = torch.constant.int 2 - %int0_2140 = torch.constant.int 0 + %int1_2137 = torch.constant.int 1 + %int8_2138 = torch.constant.int 8 + %2308 = torch.prim.ListConstruct %int1_2136, %int1_2137, %int8_2138 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2309 = torch.aten.view %2307, %2308 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_2139 = torch.constant.none + %2310 = torch.aten.clone %118, %none_2139 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2311 = torch.aten.detach %2310 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2312 = torch.aten.detach %2311 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2313 = torch.aten.detach %2312 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_2140 = torch.constant.int 1 %int1_2141 = torch.constant.int 1 - %1958 = torch.aten.slice.Tensor %1957, %int2_2139, %int0_2140, %1947, %int1_2141 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> 
!torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1958, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_2142 = torch.constant.int 0 - %1959 = torch.aten.clone %1958, %int0_2142 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %1959, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_2143 = torch.constant.int 1 - %1960 = torch.aten.size.int %1956, %int1_2143 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_2144 = torch.constant.int 32 - %1961 = torch.aten.mul.int %1960, %int32_2144 : !torch.int, !torch.int -> !torch.int - %int4_2145 = torch.constant.int 4 - %int8_2146 = torch.constant.int 8 - %int128_2147 = torch.constant.int 128 - %1962 = torch.prim.ListConstruct %int4_2145, %1961, %int8_2146, %int128_2147 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1963 = torch.aten._unsafe_view %1959, %1962 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1963, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_2148 = torch.constant.int 0 - %int0_2149 = torch.constant.int 0 - %int9223372036854775807_2150 = torch.constant.int 9223372036854775807 + %int1_2142 = torch.constant.int 1 + %2314 = torch.prim.ListConstruct %int1_2140, %int1_2141, %int1_2142 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2315 = torch.aten.view %2313, %2314 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_2143 = torch.constant.int 32 + %2316 = torch.aten.mul.Scalar %2303, %int32_2143 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_2144 = torch.constant.int 8 + %int1_2145 = torch.constant.int 1 + %2317 = torch.aten.add.Scalar %2316, %int8_2144, %int1_2145 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_2146 = torch.constant.int 2 + %2318 = torch.aten.mul.Scalar %2317, %int2_2146 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_2147 = torch.constant.int 1 + %2319 = torch.aten.add.Tensor %2318, %2315, %int1_2147 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_2148 = torch.constant.int 8 + %2320 = torch.aten.mul.Scalar %2319, %int8_2148 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_2149 = torch.constant.int 1 + %2321 = torch.aten.add.Tensor %2320, %2309, %int1_2149 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_2150 = torch.constant.int 32 + %2322 = torch.aten.mul.Scalar %2321, %int32_2150 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_2151 = torch.constant.int 1 - %1964 = torch.aten.slice.Tensor %1963, %int0_2148, %int0_2149, %int9223372036854775807_2150, %int1_2151 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %1964, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_2152 = torch.constant.int -2 - %1965 = torch.aten.unsqueeze %1954, %int-2_2152 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1965, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : 
!torch.vtensor<[4,?,8,1,128],f16> - %int1_2153 = torch.constant.int 1 - %1966 = torch.aten.size.int %1953, %int1_2153 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_2154 = torch.constant.int 4 + %2323 = torch.aten.add.Tensor %2322, %2306, %int1_2151 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_2152 = torch.constant.int 5 + %2324 = torch.prims.convert_element_type %2298, %int5_2152 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_2153 = torch.constant.int 32 + %int2_2154 = torch.constant.int 2 %int8_2155 = torch.constant.int 8 - %int4_2156 = torch.constant.int 4 + %int32_2156 = torch.constant.int 32 %int128_2157 = torch.constant.int 128 - %1967 = torch.prim.ListConstruct %int4_2154, %1966, %int8_2155, %int4_2156, %int128_2157 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2158 = torch.constant.bool false - %1968 = torch.aten.expand %1965, %1967, %false_2158 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1968, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2159 = torch.constant.int 0 - %1969 = torch.aten.clone %1968, %int0_2159 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1969, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2160 = torch.constant.int 4 - %int32_2161 = torch.constant.int 32 - %int128_2162 = torch.constant.int 128 - %1970 = torch.prim.ListConstruct %int4_2160, %1966, %int32_2161, %int128_2162 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1971 = torch.aten._unsafe_view %1969, %1970 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1971, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_2163 = torch.constant.int -2 - %1972 = torch.aten.unsqueeze %1964, %int-2_2163 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %1972, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_2164 = torch.constant.int 1 - %1973 = torch.aten.size.int %1963, %int1_2164 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_2165 = torch.constant.int 4 - %int8_2166 = torch.constant.int 8 - %int4_2167 = torch.constant.int 4 - %int128_2168 = torch.constant.int 128 - %1974 = torch.prim.ListConstruct %int4_2165, %1973, %int8_2166, %int4_2167, %int128_2168 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2169 = torch.constant.bool false - %1975 = torch.aten.expand %1972, %1974, %false_2169 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1975, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2170 = torch.constant.int 0 - %1976 = torch.aten.clone %1975, %int0_2170 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %1976, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2171 = torch.constant.int 4 - %int32_2172 = torch.constant.int 32 - %int128_2173 = torch.constant.int 128 - 
%1977 = torch.prim.ListConstruct %int4_2171, %1973, %int32_2172, %int128_2173 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %1978 = torch.aten._unsafe_view %1976, %1977 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %1978, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %2325 = torch.prim.ListConstruct %456, %int32_2153, %int2_2154, %int8_2155, %int32_2156, %int128_2157 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %2326 = torch.aten.view %2146, %2325 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2326, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_2158 = torch.constant.int 128 + %2327 = torch.prim.ListConstruct %596, %int128_2158 : (!torch.int, !torch.int) -> !torch.list<int> + %2328 = torch.aten.view %2326, %2327 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2328, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %2329 = torch.prim.ListConstruct %2323 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>> + %false_2159 = torch.constant.bool false + %2330 = torch.aten.index_put %2328, %2329, %2324, %false_2159 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2330, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_2160 = torch.constant.int 32 + %int2_2161 = torch.constant.int 2 + %int8_2162 = torch.constant.int 8 + %int32_2163 = torch.constant.int 32 + %int128_2164 = torch.constant.int 128 + %2331 = torch.prim.ListConstruct %456, %int32_2160, %int2_2161, %int8_2162, %int32_2163, %int128_2164 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %2332 = torch.aten.view %2330, %2331 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2332, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_2165 = torch.constant.int 2097152 + %2333 = torch.prim.ListConstruct %456, %int2097152_2165 : (!torch.int, !torch.int) -> !torch.list<int> + %2334 = torch.aten.view %2332, %2333 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2334, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_2166 = torch.constant.int 32 + %int2_2167 = torch.constant.int 2 + %int8_2168 = torch.constant.int 8 + %int32_2169 = torch.constant.int 32 + %int128_2170 = torch.constant.int 128 + %2335 = torch.prim.ListConstruct %456, %int32_2166, %int2_2167, %int8_2168, %int32_2169, %int128_2170 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %2336 = torch.aten.view %2334, %2335 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2336, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_2171 = torch.constant.int 128 + %2337 = torch.prim.ListConstruct %596, %int128_2171 : (!torch.int, !torch.int) -> !torch.list<int> + %2338 = torch.aten.view %2336, %2337 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> ->
!torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2338, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_2172 = torch.constant.none + %2339 = torch.aten.clone %119, %none_2172 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2340 = torch.aten.detach %2339 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2341 = torch.aten.detach %2340 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2342 = torch.aten.detach %2341 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_2173 = torch.constant.int 1 %int1_2174 = torch.constant.int 1 - %int2_2175 = torch.constant.int 2 - %1979 = torch.aten.transpose.int %1859, %int1_2174, %int2_2175 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_2176 = torch.constant.int 1 - %int2_2177 = torch.constant.int 2 - %1980 = torch.aten.transpose.int %1971, %int1_2176, %int2_2177 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1980, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2175 = torch.constant.int 1 + %2343 = torch.prim.ListConstruct %int1_2173, %int1_2174, %int1_2175 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2344 = torch.aten.view %2342, %2343 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_2176 = torch.constant.int 32 + %2345 = torch.aten.mul.Scalar %2303, %int32_2176 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_2177 = torch.constant.int 8 %int1_2178 = torch.constant.int 1 + %2346 = torch.aten.add.Scalar %2345, %int8_2177, %int1_2178 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> %int2_2179 = torch.constant.int 2 - %1981 = torch.aten.transpose.int %1978, %int1_2178, %int2_2179 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %1981, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_2180 = torch.constant.float 0.000000e+00 - %false_2181 = torch.constant.bool false - %none_2182 = torch.constant.none - %1982:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%1979, %1980, %1981, %float0.000000e00_2180, %false_2181, %368, %none_2182) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_2183 = torch.constant.int 1 - %int2_2184 = torch.constant.int 2 - %1983 = torch.aten.transpose.int %1982#0, %int1_2183, %int2_2184 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_2185 = torch.constant.int 4 - %int1_2186 = torch.constant.int 1 - %int4096_2187 = torch.constant.int 4096 - %1984 = torch.prim.ListConstruct %int4_2185, %int1_2186, %int4096_2187 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %1985 = torch.aten.view %1983, %1984 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_2188 = torch.constant.int -2 - %int-1_2189 = torch.constant.int -1 - %1986 = torch.aten.transpose.int %84, %int-2_2188, %int-1_2189 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2190 = torch.constant.int 4 - 
%int4096_2191 = torch.constant.int 4096 - %1987 = torch.prim.ListConstruct %int4_2190, %int4096_2191 : (!torch.int, !torch.int) -> !torch.list<int> - %1988 = torch.aten.view %1985, %1987 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16> - %1989 = torch.aten.mm %1988, %1986 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_2192 = torch.constant.int 4 - %int1_2193 = torch.constant.int 1 - %int4096_2194 = torch.constant.int 4096 - %1990 = torch.prim.ListConstruct %int4_2192, %int1_2193, %int4096_2194 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> - %1991 = torch.aten.view %1989, %1990 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16> - %int1_2195 = torch.constant.int 1 - %1992 = torch.aten.add.Tensor %1819, %1991, %int1_2195 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_2196 = torch.constant.int 6 - %1993 = torch.prims.convert_element_type %1992, %int6_2196 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %2347 = torch.aten.mul.Scalar %2346, %int2_2179 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_2180 = torch.constant.int 1 + %2348 = torch.aten.add.Tensor %2347, %2344, %int1_2180 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_2181 = torch.constant.int 8 + %2349 = torch.aten.mul.Scalar %2348, %int8_2181 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_2182 = torch.constant.int 1 + %2350 = torch.aten.add.Tensor %2349, %2309, %int1_2182 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_2183 = torch.constant.int 32 + %2351 = torch.aten.mul.Scalar %2350, %int32_2183 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_2184 = torch.constant.int 1 + %2352 = torch.aten.add.Tensor %2351, %2306, %int1_2184 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_2185 = torch.constant.int 5 + %2353 = torch.prims.convert_element_type %2278, %int5_2185 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %2354 = torch.prim.ListConstruct %2352 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>> + %false_2186 = torch.constant.bool false + %2355 = torch.aten.index_put %2338, %2354, %2353, %false_2186 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2355, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_2187 = torch.constant.int 32 + %int2_2188 = torch.constant.int 2 + %int8_2189 = torch.constant.int 8 + %int32_2190 = torch.constant.int 32 + %int128_2191 = torch.constant.int 128 + %2356 = torch.prim.ListConstruct %456, %int32_2187, %int2_2188, %int8_2189, %int32_2190, %int128_2191 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %2357 = torch.aten.view %2355, %2356 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2357, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_2192 = torch.constant.int 2097152 + %2358 = torch.prim.ListConstruct %456, %int2097152_2192 : (!torch.int, !torch.int) ->
!torch.list + %2359 = torch.aten.view %2357, %2358 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2359, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_2193 = torch.constant.none + %2360 = torch.aten.clone %120, %none_2193 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2361 = torch.aten.detach %2360 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2362 = torch.aten.detach %2361 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2363 = torch.aten.detach %2362 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_2194 = torch.constant.none + %2364 = torch.aten.clone %121, %none_2194 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2365 = torch.aten.detach %2364 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2366 = torch.aten.detach %2365 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2367 = torch.aten.detach %2366 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_2195 = torch.constant.none + %2368 = torch.aten.clone %122, %none_2195 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2369 = torch.aten.detach %2368 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2370 = torch.aten.detach %2369 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2371 = torch.aten.detach %2370 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_2196 = torch.constant.int 32 %int2_2197 = torch.constant.int 2 - %1994 = torch.aten.pow.Tensor_Scalar %1993, %int2_2197 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_2198 = torch.constant.int -1 - %1995 = torch.prim.ListConstruct %int-1_2198 : (!torch.int) -> !torch.list - %true_2199 = torch.constant.bool true - %none_2200 = torch.constant.none - %1996 = torch.aten.mean.dim %1994, %1995, %true_2199, %none_2200 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_2201 = torch.constant.float 9.9999997473787516E-6 - %int1_2202 = torch.constant.int 1 - %1997 = torch.aten.add.Scalar %1996, %float9.999990e-06_2201, %int1_2202 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %1998 = torch.aten.rsqrt %1997 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %1999 = torch.aten.mul.Tensor %1993, %1998 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_2203 = torch.constant.int 5 - %2000 = torch.prims.convert_element_type %1999, %int5_2203 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %2001 = torch.aten.mul.Tensor %85, %2000 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_2204 = torch.constant.int 5 - %2002 = torch.prims.convert_element_type %2001, %int5_2204 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_2205 = torch.constant.int -2 - %int-1_2206 = torch.constant.int -1 - %2003 = torch.aten.transpose.int %86, %int-2_2205, %int-1_2206 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_2207 = torch.constant.int 4 - %int4096_2208 = torch.constant.int 4096 - %2004 = torch.prim.ListConstruct %int4_2207, %int4096_2208 : (!torch.int, !torch.int) -> !torch.list - %2005 = torch.aten.view %2002, %2004 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> 
!torch.vtensor<[4,4096],f16> - %2006 = torch.aten.mm %2005, %2003 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_2209 = torch.constant.int 4 - %int1_2210 = torch.constant.int 1 - %int14336_2211 = torch.constant.int 14336 - %2007 = torch.prim.ListConstruct %int4_2209, %int1_2210, %int14336_2211 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> - %2008 = torch.aten.view %2006, %2007 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16> - %2009 = torch.aten.silu %2008 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_2212 = torch.constant.int -2 - %int-1_2213 = torch.constant.int -1 - %2010 = torch.aten.transpose.int %87, %int-2_2212, %int-1_2213 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int8_2198 = torch.constant.int 8 + %int32_2199 = torch.constant.int 32 + %int128_2200 = torch.constant.int 128 + %2372 = torch.prim.ListConstruct %456, %int32_2196, %int2_2197, %int8_2198, %int32_2199, %int128_2200 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %2373 = torch.aten.view %2359, %2372 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2373, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %2374 = torch_c.to_builtin_tensor %2373 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16> + %2375 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_2201 = tensor.cast %2375 : tensor<4x?xi64> to tensor<?x?xi64> + %2376 = torch_c.to_builtin_tensor %2363 : !torch.vtensor<[],si64> -> tensor<i64> + %2377 = torch_c.to_builtin_tensor %2367 : !torch.vtensor<[],si64> -> tensor<i64> + %2378 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%2374, %cast_2201, %2376, %2377) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16> + %cast_2202 = tensor.cast %2378 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16> + %2379 = torch_c.from_builtin_tensor %cast_2202 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %2379, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %2380 = torch_c.to_builtin_tensor %2373 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16> + %2381 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_2203 = tensor.cast %2381 : tensor<4x?xi64> to tensor<?x?xi64> + %2382 = torch_c.to_builtin_tensor %2363 : !torch.vtensor<[],si64> -> tensor<i64> + %2383 = torch_c.to_builtin_tensor %2371 : !torch.vtensor<[],si64> -> tensor<i64> + %2384 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%2380, %cast_2203, %2382, %2383) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16> + %cast_2204 = tensor.cast %2384 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16> + %2385 = torch_c.from_builtin_tensor %cast_2204 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %2385, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_2205 = torch.constant.int 2 + %int3_2206 = torch.constant.int 3 + %2386 =
torch.aten.transpose.int %2379, %int2_2205, %int3_2206 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2386, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_2207 = torch.constant.int 0 + %2387 = torch.aten.clone %2386, %int0_2207 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2387, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_2208 = torch.constant.int 4 + %int8_2209 = torch.constant.int 8 + %int128_2210 = torch.constant.int 128 + %2388 = torch.prim.ListConstruct %int4_2208, %457, %int8_2209, %int128_2210 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2389 = torch.aten._unsafe_view %2387, %2388 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2389, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_2211 = torch.constant.int 2 + %int3_2212 = torch.constant.int 3 + %2390 = torch.aten.transpose.int %2385, %int2_2211, %int3_2212 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2390, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_2213 = torch.constant.int 0 + %2391 = torch.aten.clone %2390, %int0_2213 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2391, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int4_2214 = torch.constant.int 4 - %int4096_2215 = torch.constant.int 4096 - %2011 = torch.prim.ListConstruct %int4_2214, %int4096_2215 : (!torch.int, !torch.int) -> !torch.list - %2012 = torch.aten.view %2002, %2011 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2013 = torch.aten.mm %2012, %2010 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_2216 = torch.constant.int 4 - %int1_2217 = torch.constant.int 1 - %int14336_2218 = torch.constant.int 14336 - %2014 = torch.prim.ListConstruct %int4_2216, %int1_2217, %int14336_2218 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2015 = torch.aten.view %2013, %2014 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %2016 = torch.aten.mul.Tensor %2009, %2015 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_2219 = torch.constant.int -2 - %int-1_2220 = torch.constant.int -1 - %2017 = torch.aten.transpose.int %88, %int-2_2219, %int-1_2220 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_2221 = torch.constant.int 4 - %int14336_2222 = torch.constant.int 14336 - %2018 = torch.prim.ListConstruct %int4_2221, %int14336_2222 : (!torch.int, !torch.int) -> !torch.list - %2019 = torch.aten.view %2016, %2018 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %2020 = torch.aten.mm %2019, %2017 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_2223 = torch.constant.int 4 - %int1_2224 = torch.constant.int 1 - %int4096_2225 = torch.constant.int 4096 - %2021 = torch.prim.ListConstruct %int4_2223, %int1_2224, %int4096_2225 : (!torch.int, !torch.int, 
!torch.int) -> !torch.list - %2022 = torch.aten.view %2020, %2021 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_2226 = torch.constant.int 1 - %2023 = torch.aten.add.Tensor %1992, %2022, %int1_2226 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_2227 = torch.constant.int 6 - %2024 = torch.prims.convert_element_type %2023, %int6_2227 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_2228 = torch.constant.int 2 - %2025 = torch.aten.pow.Tensor_Scalar %2024, %int2_2228 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_2229 = torch.constant.int -1 - %2026 = torch.prim.ListConstruct %int-1_2229 : (!torch.int) -> !torch.list - %true_2230 = torch.constant.bool true - %none_2231 = torch.constant.none - %2027 = torch.aten.mean.dim %2025, %2026, %true_2230, %none_2231 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_2232 = torch.constant.float 9.9999997473787516E-6 - %int1_2233 = torch.constant.int 1 - %2028 = torch.aten.add.Scalar %2027, %float9.999990e-06_2232, %int1_2233 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %2029 = torch.aten.rsqrt %2028 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %2030 = torch.aten.mul.Tensor %2024, %2029 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_2234 = torch.constant.int 5 - %2031 = torch.prims.convert_element_type %2030, %int5_2234 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %2032 = torch.aten.mul.Tensor %89, %2031 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_2235 = torch.constant.int 5 - %2033 = torch.prims.convert_element_type %2032, %int5_2235 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_2236 = torch.constant.int -2 - %int-1_2237 = torch.constant.int -1 - %2034 = torch.aten.transpose.int %90, %int-2_2236, %int-1_2237 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2238 = torch.constant.int 4 - %int4096_2239 = torch.constant.int 4096 - %2035 = torch.prim.ListConstruct %int4_2238, %int4096_2239 : (!torch.int, !torch.int) -> !torch.list - %2036 = torch.aten.view %2033, %2035 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2037 = torch.aten.mm %2036, %2034 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_2240 = torch.constant.int 4 + %int8_2215 = torch.constant.int 8 + %int128_2216 = torch.constant.int 128 + %2392 = torch.prim.ListConstruct %int4_2214, %457, %int8_2215, %int128_2216 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2393 = torch.aten._unsafe_view %2391, %2392 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2393, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_2217 = torch.constant.int -2 + %2394 = torch.aten.unsqueeze %2389, %int-2_2217 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2394, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_2218 = torch.constant.int 4 
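// Annotation: grouped-query attention head expansion (shapes read from the
// torch.bind_symbolic_shape ops above). The K pages gathered from the paged
// cache, flattened to [4, s0*32, 8, 128] in %2389, are unsqueezed to
// [4, s0*32, 8, 1, 128] (%2394) and, in the ops below, expanded to
// [4, s0*32, 8, 4, 128] and re-viewed as [4, s0*32, 32, 128] (%2399), so each
// of the 8 KV heads serves 4 of the 32 query heads (32 / 8 = 4). V repeats the
// same unsqueeze/expand/_unsafe_view pattern (%2400..%2405) ahead of the
// _scaled_dot_product_flash_attention_for_cpu call (%2409).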
+ %int8_2219 = torch.constant.int 8 + %int4_2220 = torch.constant.int 4 + %int128_2221 = torch.constant.int 128 + %2395 = torch.prim.ListConstruct %int4_2218, %457, %int8_2219, %int4_2220, %int128_2221 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2222 = torch.constant.bool false + %2396 = torch.aten.expand %2394, %2395, %false_2222 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2396, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2223 = torch.constant.int 0 + %2397 = torch.aten.clone %2396, %int0_2223 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2397, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_2224 = torch.constant.int 4 + %int32_2225 = torch.constant.int 32 + %int128_2226 = torch.constant.int 128 + %2398 = torch.prim.ListConstruct %int4_2224, %457, %int32_2225, %int128_2226 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2399 = torch.aten._unsafe_view %2397, %2398 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2399, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_2227 = torch.constant.int -2 + %2400 = torch.aten.unsqueeze %2393, %int-2_2227 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2400, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_2228 = torch.constant.int 4 + %int8_2229 = torch.constant.int 8 + %int4_2230 = torch.constant.int 4 + %int128_2231 = torch.constant.int 128 + %2401 = torch.prim.ListConstruct %int4_2228, %457, %int8_2229, %int4_2230, %int128_2231 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2232 = torch.constant.bool false + %2402 = torch.aten.expand %2400, %2401, %false_2232 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2402, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2233 = torch.constant.int 0 + %2403 = torch.aten.clone %2402, %int0_2233 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2403, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_2234 = torch.constant.int 4 + %int32_2235 = torch.constant.int 32 + %int128_2236 = torch.constant.int 128 + %2404 = torch.prim.ListConstruct %int4_2234, %457, %int32_2235, %int128_2236 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2405 = torch.aten._unsafe_view %2403, %2404 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2405, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_2237 = torch.constant.int 1 + %int2_2238 = torch.constant.int 2 + %2406 = torch.aten.transpose.int %2288, %int1_2237, %int2_2238 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_2239 = torch.constant.int 1 + %int2_2240 = torch.constant.int 2 + %2407 = torch.aten.transpose.int %2399, %int1_2239, %int2_2240 : !torch.vtensor<[4,?,32,128],f16>, 
!torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2407, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_2241 = torch.constant.int 1 - %int4096_2242 = torch.constant.int 4096 - %2038 = torch.prim.ListConstruct %int4_2240, %int1_2241, %int4096_2242 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2039 = torch.aten.view %2037, %2038 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_2243 = torch.constant.int -2 - %int-1_2244 = torch.constant.int -1 - %2040 = torch.aten.transpose.int %91, %int-2_2243, %int-1_2244 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2245 = torch.constant.int 4 - %int4096_2246 = torch.constant.int 4096 - %2041 = torch.prim.ListConstruct %int4_2245, %int4096_2246 : (!torch.int, !torch.int) -> !torch.list - %2042 = torch.aten.view %2033, %2041 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2043 = torch.aten.mm %2042, %2040 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_2247 = torch.constant.int 4 - %int1_2248 = torch.constant.int 1 - %int1024_2249 = torch.constant.int 1024 - %2044 = torch.prim.ListConstruct %int4_2247, %int1_2248, %int1024_2249 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2045 = torch.aten.view %2043, %2044 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_2250 = torch.constant.int -2 - %int-1_2251 = torch.constant.int -1 - %2046 = torch.aten.transpose.int %92, %int-2_2250, %int-1_2251 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2252 = torch.constant.int 4 - %int4096_2253 = torch.constant.int 4096 - %2047 = torch.prim.ListConstruct %int4_2252, %int4096_2253 : (!torch.int, !torch.int) -> !torch.list - %2048 = torch.aten.view %2033, %2047 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2049 = torch.aten.mm %2048, %2046 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int2_2242 = torch.constant.int 2 + %2408 = torch.aten.transpose.int %2405, %int1_2241, %int2_2242 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2408, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_2243 = torch.constant.float 0.000000e+00 + %false_2244 = torch.constant.bool false + %none_2245 = torch.constant.none + %2409:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2406, %2407, %2408, %float0.000000e00_2243, %false_2244, %470, %none_2245) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_2246 = torch.constant.int 1 + %int2_2247 = torch.constant.int 2 + %2410 = torch.aten.transpose.int %2409#0, %int1_2246, %int2_2247 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_2248 = torch.constant.int 4 + %int1_2249 = torch.constant.int 1 + %int4096_2250 = torch.constant.int 4096 + %2411 = torch.prim.ListConstruct %int4_2248, %int1_2249, %int4096_2250 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2412 = torch.aten.view 
%2410, %2411 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_2251 = torch.constant.int -2 + %int-1_2252 = torch.constant.int -1 + %2413 = torch.aten.transpose.int %123, %int-2_2251, %int-1_2252 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2253 = torch.constant.int 5 + %2414 = torch.prims.convert_element_type %2413, %int5_2253 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> %int4_2254 = torch.constant.int 4 - %int1_2255 = torch.constant.int 1 - %int1024_2256 = torch.constant.int 1024 - %2050 = torch.prim.ListConstruct %int4_2254, %int1_2255, %int1024_2256 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2051 = torch.aten.view %2049, %2050 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_2257 = torch.constant.int 4 - %int1_2258 = torch.constant.int 1 - %int32_2259 = torch.constant.int 32 - %int128_2260 = torch.constant.int 128 - %2052 = torch.prim.ListConstruct %int4_2257, %int1_2258, %int32_2259, %int128_2260 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2053 = torch.aten.view %2039, %2052 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_2261 = torch.constant.int 4 - %int1_2262 = torch.constant.int 1 - %int8_2263 = torch.constant.int 8 - %int128_2264 = torch.constant.int 128 - %2054 = torch.prim.ListConstruct %int4_2261, %int1_2262, %int8_2263, %int128_2264 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2055 = torch.aten.view %2045, %2054 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_2265 = torch.constant.int 4 + %int4096_2255 = torch.constant.int 4096 + %2415 = torch.prim.ListConstruct %int4_2254, %int4096_2255 : (!torch.int, !torch.int) -> !torch.list + %2416 = torch.aten.view %2412, %2415 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2417 = torch.aten.mm %2416, %2414 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_2256 = torch.constant.int 4 + %int1_2257 = torch.constant.int 1 + %int4096_2258 = torch.constant.int 4096 + %2418 = torch.prim.ListConstruct %int4_2256, %int1_2257, %int4096_2258 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2419 = torch.aten.view %2417, %2418 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_2259 = torch.constant.int 1 + %2420 = torch.aten.add.Tensor %2241, %2419, %int1_2259 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_2260 = torch.constant.int 6 + %2421 = torch.prims.convert_element_type %2420, %int6_2260 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_2261 = torch.constant.int 2 + %2422 = torch.aten.pow.Tensor_Scalar %2421, %int2_2261 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_2262 = torch.constant.int -1 + %2423 = torch.prim.ListConstruct %int-1_2262 : (!torch.int) -> !torch.list + %true_2263 = torch.constant.bool true + %none_2264 = torch.constant.none + %2424 = torch.aten.mean.dim %2422, %2423, %true_2263, %none_2264 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_2265 = torch.constant.float 9.9999997473787516E-6 %int1_2266 = torch.constant.int 1 - %int8_2267 = torch.constant.int 8 - 
%int128_2268 = torch.constant.int 128 - %2056 = torch.prim.ListConstruct %int4_2265, %int1_2266, %int8_2267, %int128_2268 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2057 = torch.aten.view %2051, %2056 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_2269 = torch.constant.int 6 - %2058 = torch.prims.convert_element_type %2053, %int6_2269 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %2059 = torch_c.to_builtin_tensor %2058 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %2060 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %2061 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%2059, %2060) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %2062 = torch_c.from_builtin_tensor %2061 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_2270 = torch.constant.int 5 - %2063 = torch.prims.convert_element_type %2062, %int5_2270 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_2271 = torch.constant.int 6 - %2064 = torch.prims.convert_element_type %2055, %int6_2271 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %2065 = torch_c.to_builtin_tensor %2064 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %2066 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %2067 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%2065, %2066) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %2068 = torch_c.from_builtin_tensor %2067 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_2272 = torch.constant.int 5 - %2069 = torch.prims.convert_element_type %2068, %int5_2272 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_2273 = torch.constant.int 32 - %2070 = torch.aten.floor_divide.Scalar %arg2, %int32_2273 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2274 = torch.constant.int 1 - %2071 = torch.aten.unsqueeze %2070, %int1_2274 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %2425 = torch.aten.add.Scalar %2424, %float9.999990e-06_2265, %int1_2266 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %2426 = torch.aten.rsqrt %2425 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %2427 = torch.aten.mul.Tensor %2421, %2426 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_2267 = torch.constant.int 5 + %2428 = torch.prims.convert_element_type %2427, %int5_2267 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %2429 = torch.aten.mul.Tensor %124, %2428 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_2268 = torch.constant.int 5 + %2430 = torch.prims.convert_element_type %2429, %int5_2268 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_2269 = torch.constant.int -2 + %int-1_2270 = torch.constant.int -1 + %2431 = torch.aten.transpose.int %125, %int-2_2269, %int-1_2270 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2271 = torch.constant.int 5 + %2432 = torch.prims.convert_element_type %2431, %int5_2271 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + 
%int4_2272 = torch.constant.int 4 + %int4096_2273 = torch.constant.int 4096 + %2433 = torch.prim.ListConstruct %int4_2272, %int4096_2273 : (!torch.int, !torch.int) -> !torch.list + %2434 = torch.aten.view %2430, %2433 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2435 = torch.aten.mm %2434, %2432 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_2274 = torch.constant.int 4 %int1_2275 = torch.constant.int 1 - %false_2276 = torch.constant.bool false - %2072 = torch.aten.gather %arg3, %int1_2275, %2071, %false_2276 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_2277 = torch.constant.int 32 - %2073 = torch.aten.remainder.Scalar %arg2, %int32_2277 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2278 = torch.constant.int 1 - %2074 = torch.aten.unsqueeze %2073, %int1_2278 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_2279 = torch.constant.none - %2075 = torch.aten.clone %93, %none_2279 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_2280 = torch.constant.int 0 - %2076 = torch.aten.unsqueeze %2075, %int0_2280 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_2281 = torch.constant.int 4 - %int1_2282 = torch.constant.int 1 - %2077 = torch.prim.ListConstruct %int4_2281, %int1_2282 : (!torch.int, !torch.int) -> !torch.list + %int14336_2276 = torch.constant.int 14336 + %2436 = torch.prim.ListConstruct %int4_2274, %int1_2275, %int14336_2276 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2437 = torch.aten.view %2435, %2436 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %2438 = torch.aten.silu %2437 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_2277 = torch.constant.int -2 + %int-1_2278 = torch.constant.int -1 + %2439 = torch.aten.transpose.int %126, %int-2_2277, %int-1_2278 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2279 = torch.constant.int 5 + %2440 = torch.prims.convert_element_type %2439, %int5_2279 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_2280 = torch.constant.int 4 + %int4096_2281 = torch.constant.int 4096 + %2441 = torch.prim.ListConstruct %int4_2280, %int4096_2281 : (!torch.int, !torch.int) -> !torch.list + %2442 = torch.aten.view %2430, %2441 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2443 = torch.aten.mm %2442, %2440 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_2282 = torch.constant.int 4 %int1_2283 = torch.constant.int 1 - %int1_2284 = torch.constant.int 1 - %2078 = torch.prim.ListConstruct %int1_2283, %int1_2284 : (!torch.int, !torch.int) -> !torch.list - %int4_2285 = torch.constant.int 4 - %int0_2286 = torch.constant.int 0 - %cpu_2287 = torch.constant.device "cpu" - %false_2288 = torch.constant.bool false - %2079 = torch.aten.empty_strided %2077, %2078, %int4_2285, %int0_2286, %cpu_2287, %false_2288 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int8_2289 = torch.constant.int 8 - %2080 = torch.aten.fill.Scalar %2079, %int8_2289 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int14336_2284 = torch.constant.int 14336 + %2444 = torch.prim.ListConstruct 
%int4_2282, %int1_2283, %int14336_2284 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2445 = torch.aten.view %2443, %2444 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %2446 = torch.aten.mul.Tensor %2438, %2445 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_2285 = torch.constant.int -2 + %int-1_2286 = torch.constant.int -1 + %2447 = torch.aten.transpose.int %127, %int-2_2285, %int-1_2286 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_2287 = torch.constant.int 5 + %2448 = torch.prims.convert_element_type %2447, %int5_2287 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_2288 = torch.constant.int 4 + %int14336_2289 = torch.constant.int 14336 + %2449 = torch.prim.ListConstruct %int4_2288, %int14336_2289 : (!torch.int, !torch.int) -> !torch.list + %2450 = torch.aten.view %2446, %2449 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %2451 = torch.aten.mm %2450, %2448 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_2290 = torch.constant.int 4 %int1_2291 = torch.constant.int 1 - %2081 = torch.prim.ListConstruct %int4_2290, %int1_2291 : (!torch.int, !torch.int) -> !torch.list - %2082 = torch.aten.repeat %2076, %2081 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_2292 = torch.constant.int 32 - %2083 = torch.aten.mul.Scalar %2072, %int32_2292 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int4096_2292 = torch.constant.int 4096 + %2452 = torch.prim.ListConstruct %int4_2290, %int1_2291, %int4096_2292 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2453 = torch.aten.view %2451, %2452 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_2293 = torch.constant.int 1 - %2084 = torch.aten.add.Tensor %2083, %2080, %int1_2293 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_2294 = torch.constant.int 2 - %2085 = torch.aten.mul.Scalar %2084, %int2_2294 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2295 = torch.constant.int 1 - %2086 = torch.aten.add.Tensor %2085, %2082, %int1_2295 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_2296 = torch.constant.int 32 - %2087 = torch.aten.mul.Scalar %2086, %int32_2296 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2297 = torch.constant.int 1 - %2088 = torch.aten.add.Tensor %2087, %2074, %int1_2297 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_2298 = torch.constant.int 32 - %int2_2299 = torch.constant.int 2 - %int32_2300 = torch.constant.int 32 - %int8_2301 = torch.constant.int 8 - %int128_2302 = torch.constant.int 128 - %2089 = torch.prim.ListConstruct %437, %int32_2298, %int2_2299, %int32_2300, %int8_2301, %int128_2302 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2090 = torch.aten.view %1926, %2089 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2090, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2303 = torch.constant.int 32 - %2091 = torch.aten.mul.int %437, %int32_2303 : 
!torch.int, !torch.int -> !torch.int - %int2_2304 = torch.constant.int 2 - %2092 = torch.aten.mul.int %2091, %int2_2304 : !torch.int, !torch.int -> !torch.int - %int32_2305 = torch.constant.int 32 - %2093 = torch.aten.mul.int %2092, %int32_2305 : !torch.int, !torch.int -> !torch.int - %int8_2306 = torch.constant.int 8 - %int128_2307 = torch.constant.int 128 - %2094 = torch.prim.ListConstruct %2093, %int8_2306, %int128_2307 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2095 = torch.aten.view %2090, %2094 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2095, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %2096 = torch.prim.ListConstruct %2088 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_2308 = torch.constant.bool false - %2097 = torch.aten.index_put %2095, %2096, %2069, %false_2308 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2097, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_2309 = torch.constant.int 32 - %int2_2310 = torch.constant.int 2 - %int32_2311 = torch.constant.int 32 - %int8_2312 = torch.constant.int 8 - %int128_2313 = torch.constant.int 128 - %2098 = torch.prim.ListConstruct %437, %int32_2309, %int2_2310, %int32_2311, %int8_2312, %int128_2313 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2099 = torch.aten.view %2097, %2098 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2099, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2314 = torch.constant.int 2097152 - %2100 = torch.prim.ListConstruct %437, %int2097152_2314 : (!torch.int, !torch.int) -> !torch.list - %2101 = torch.aten.view %2099, %2100 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2101, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_2315 = torch.constant.int 32 - %int2_2316 = torch.constant.int 2 - %int32_2317 = torch.constant.int 32 - %int8_2318 = torch.constant.int 8 - %int128_2319 = torch.constant.int 128 - %2102 = torch.prim.ListConstruct %437, %int32_2315, %int2_2316, %int32_2317, %int8_2318, %int128_2319 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2103 = torch.aten.view %2101, %2102 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2103, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_2320 = torch.constant.int 8 - %int128_2321 = torch.constant.int 128 - %2104 = torch.prim.ListConstruct %2093, %int8_2320, %int128_2321 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2105 = torch.aten.view %2103, %2104 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2105, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_2322 = torch.constant.int 32 - %2106 = torch.aten.floor_divide.Scalar %arg2, %int32_2322 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2323 = torch.constant.int 1 - %2107 = torch.aten.unsqueeze %2106, %int1_2323 : !torch.vtensor<[4],si64>, !torch.int -> 
!torch.vtensor<[4,1],si64> - %int1_2324 = torch.constant.int 1 - %false_2325 = torch.constant.bool false - %2108 = torch.aten.gather %arg3, %int1_2324, %2107, %false_2325 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_2326 = torch.constant.int 32 - %2109 = torch.aten.remainder.Scalar %arg2, %int32_2326 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2327 = torch.constant.int 1 - %2110 = torch.aten.unsqueeze %2109, %int1_2327 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_2328 = torch.constant.none - %2111 = torch.aten.clone %94, %none_2328 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_2329 = torch.constant.int 0 - %2112 = torch.aten.unsqueeze %2111, %int0_2329 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_2330 = torch.constant.int 4 - %int1_2331 = torch.constant.int 1 - %2113 = torch.prim.ListConstruct %int4_2330, %int1_2331 : (!torch.int, !torch.int) -> !torch.list + %2454 = torch.aten.add.Tensor %2420, %2453, %int1_2293 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_2294 = torch.constant.int 6 + %2455 = torch.prims.convert_element_type %2454, %int6_2294 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_2295 = torch.constant.int 2 + %2456 = torch.aten.pow.Tensor_Scalar %2455, %int2_2295 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_2296 = torch.constant.int -1 + %2457 = torch.prim.ListConstruct %int-1_2296 : (!torch.int) -> !torch.list + %true_2297 = torch.constant.bool true + %none_2298 = torch.constant.none + %2458 = torch.aten.mean.dim %2456, %2457, %true_2297, %none_2298 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_2299 = torch.constant.float 9.9999997473787516E-6 + %int1_2300 = torch.constant.int 1 + %2459 = torch.aten.add.Scalar %2458, %float9.999990e-06_2299, %int1_2300 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %2460 = torch.aten.rsqrt %2459 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %2461 = torch.aten.mul.Tensor %2455, %2460 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_2301 = torch.constant.int 5 + %2462 = torch.prims.convert_element_type %2461, %int5_2301 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %2463 = torch.aten.mul.Tensor %128, %2462 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_2302 = torch.constant.int 5 + %2464 = torch.prims.convert_element_type %2463, %int5_2302 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_2303 = torch.constant.int -2 + %int-1_2304 = torch.constant.int -1 + %2465 = torch.aten.transpose.int %129, %int-2_2303, %int-1_2304 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2305 = torch.constant.int 5 + %2466 = torch.prims.convert_element_type %2465, %int5_2305 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_2306 = torch.constant.int 4 + %int4096_2307 = torch.constant.int 4096 + %2467 = torch.prim.ListConstruct %int4_2306, %int4096_2307 : (!torch.int, !torch.int) -> !torch.list + %2468 = 
torch.aten.view %2464, %2467 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2469 = torch.aten.mm %2468, %2466 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_2308 = torch.constant.int 4 + %int1_2309 = torch.constant.int 1 + %int4096_2310 = torch.constant.int 4096 + %2470 = torch.prim.ListConstruct %int4_2308, %int1_2309, %int4096_2310 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2471 = torch.aten.view %2469, %2470 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_2311 = torch.constant.int -2 + %int-1_2312 = torch.constant.int -1 + %2472 = torch.aten.transpose.int %130, %int-2_2311, %int-1_2312 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2313 = torch.constant.int 5 + %2473 = torch.prims.convert_element_type %2472, %int5_2313 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_2314 = torch.constant.int 4 + %int4096_2315 = torch.constant.int 4096 + %2474 = torch.prim.ListConstruct %int4_2314, %int4096_2315 : (!torch.int, !torch.int) -> !torch.list + %2475 = torch.aten.view %2464, %2474 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2476 = torch.aten.mm %2475, %2473 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_2316 = torch.constant.int 4 + %int1_2317 = torch.constant.int 1 + %int1024_2318 = torch.constant.int 1024 + %2477 = torch.prim.ListConstruct %int4_2316, %int1_2317, %int1024_2318 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2478 = torch.aten.view %2476, %2477 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_2319 = torch.constant.int -2 + %int-1_2320 = torch.constant.int -1 + %2479 = torch.aten.transpose.int %131, %int-2_2319, %int-1_2320 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2321 = torch.constant.int 5 + %2480 = torch.prims.convert_element_type %2479, %int5_2321 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_2322 = torch.constant.int 4 + %int4096_2323 = torch.constant.int 4096 + %2481 = torch.prim.ListConstruct %int4_2322, %int4096_2323 : (!torch.int, !torch.int) -> !torch.list + %2482 = torch.aten.view %2464, %2481 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2483 = torch.aten.mm %2482, %2480 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_2324 = torch.constant.int 4 + %int1_2325 = torch.constant.int 1 + %int1024_2326 = torch.constant.int 1024 + %2484 = torch.prim.ListConstruct %int4_2324, %int1_2325, %int1024_2326 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2485 = torch.aten.view %2483, %2484 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_2327 = torch.constant.int 4 + %int1_2328 = torch.constant.int 1 + %int32_2329 = torch.constant.int 32 + %int128_2330 = torch.constant.int 128 + %2486 = torch.prim.ListConstruct %int4_2327, %int1_2328, %int32_2329, %int128_2330 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2487 = torch.aten.view %2471, %2486 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_2331 = torch.constant.int 4 %int1_2332 = torch.constant.int 1 - %int1_2333 = torch.constant.int 1 - %2114 = 
torch.prim.ListConstruct %int1_2332, %int1_2333 : (!torch.int, !torch.int) -> !torch.list - %int4_2334 = torch.constant.int 4 - %int0_2335 = torch.constant.int 0 - %cpu_2336 = torch.constant.device "cpu" - %false_2337 = torch.constant.bool false - %2115 = torch.aten.empty_strided %2113, %2114, %int4_2334, %int0_2335, %cpu_2336, %false_2337 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int8_2338 = torch.constant.int 8 - %2116 = torch.aten.fill.Scalar %2115, %int8_2338 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_2339 = torch.constant.int 4 - %int1_2340 = torch.constant.int 1 - %2117 = torch.prim.ListConstruct %int4_2339, %int1_2340 : (!torch.int, !torch.int) -> !torch.list - %2118 = torch.aten.repeat %2112, %2117 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_2341 = torch.constant.int 32 - %2119 = torch.aten.mul.Scalar %2108, %int32_2341 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2342 = torch.constant.int 1 - %2120 = torch.aten.add.Tensor %2119, %2116, %int1_2342 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_2343 = torch.constant.int 2 - %2121 = torch.aten.mul.Scalar %2120, %int2_2343 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int8_2333 = torch.constant.int 8 + %int128_2334 = torch.constant.int 128 + %2488 = torch.prim.ListConstruct %int4_2331, %int1_2332, %int8_2333, %int128_2334 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2489 = torch.aten.view %2478, %2488 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_2335 = torch.constant.int 4 + %int1_2336 = torch.constant.int 1 + %int8_2337 = torch.constant.int 8 + %int128_2338 = torch.constant.int 128 + %2490 = torch.prim.ListConstruct %int4_2335, %int1_2336, %int8_2337, %int128_2338 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2491 = torch.aten.view %2485, %2490 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_2339 = torch.constant.int 1 + %int2_2340 = torch.constant.int 2 + %2492 = torch.aten.transpose.int %2487, %int1_2339, %int2_2340 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %2493 = torch.aten.mul.Tensor %2492, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_2341 = torch.constant.int 3 + %int0_2342 = torch.constant.int 0 + %int64_2343 = torch.constant.int 64 %int1_2344 = torch.constant.int 1 - %2122 = torch.aten.add.Tensor %2121, %2118, %int1_2344 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_2345 = torch.constant.int 32 - %2123 = torch.aten.mul.Scalar %2122, %int32_2345 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2346 = torch.constant.int 1 - %2124 = torch.aten.add.Tensor %2123, %2110, %int1_2346 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %2125 = torch.prim.ListConstruct %2124 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_2347 = torch.constant.bool false - %2126 = torch.aten.index_put %2105, %2125, %2057, %false_2347 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2126, 
[%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_2348 = torch.constant.int 32 - %int2_2349 = torch.constant.int 2 - %int32_2350 = torch.constant.int 32 - %int8_2351 = torch.constant.int 8 - %int128_2352 = torch.constant.int 128 - %2127 = torch.prim.ListConstruct %437, %int32_2348, %int2_2349, %int32_2350, %int8_2351, %int128_2352 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2128 = torch.aten.view %2126, %2127 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2128, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2353 = torch.constant.int 2097152 - %2129 = torch.prim.ListConstruct %437, %int2097152_2353 : (!torch.int, !torch.int) -> !torch.list - %2130 = torch.aten.view %2128, %2129 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2130, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_2354 = torch.constant.int 4 - %2131 = torch.prim.ListConstruct %int4_2354, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_2355 = torch.constant.int 1 - %2132 = torch.prim.ListConstruct %358, %int1_2355 : (!torch.int, !torch.int) -> !torch.list - %int4_2356 = torch.constant.int 4 - %int0_2357 = torch.constant.int 0 - %cpu_2358 = torch.constant.device "cpu" - %false_2359 = torch.constant.bool false - %2133 = torch.aten.empty_strided %2131, %2132, %int4_2356, %int0_2357, %cpu_2358, %false_2359 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2133, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int8_2360 = torch.constant.int 8 - %2134 = torch.aten.fill.Scalar %2133, %int8_2360 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2134, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_2361 = torch.constant.int 32 - %2135 = torch.aten.mul.Scalar %arg3, %int32_2361 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2135, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %2494 = torch.aten.slice.Tensor %2492, %int3_2341, %int0_2342, %int64_2343, %int1_2344 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_2345 = torch.constant.int 3 + %int64_2346 = torch.constant.int 64 + %int9223372036854775807_2347 = torch.constant.int 9223372036854775807 + %int1_2348 = torch.constant.int 1 + %2495 = torch.aten.slice.Tensor %2492, %int3_2345, %int64_2346, %int9223372036854775807_2347, %int1_2348 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %2496 = torch.aten.neg %2495 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %2497 = torch.prim.ListConstruct %2496, %2494 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_2349 = torch.constant.int -1 + %2498 = torch.aten.cat %2497, %int-1_2349 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %2499 = torch.aten.mul.Tensor %2498, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_2350 = torch.constant.int 1 + %2500 = 
torch.aten.add.Tensor %2493, %2499, %int1_2350 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_2351 = torch.constant.int 1 + %int2_2352 = torch.constant.int 2 + %2501 = torch.aten.transpose.int %2500, %int1_2351, %int2_2352 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_2353 = torch.constant.int 1 + %int2_2354 = torch.constant.int 2 + %2502 = torch.aten.transpose.int %2489, %int1_2353, %int2_2354 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %2503 = torch.aten.mul.Tensor %2502, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_2355 = torch.constant.int 3 + %int0_2356 = torch.constant.int 0 + %int64_2357 = torch.constant.int 64 + %int1_2358 = torch.constant.int 1 + %2504 = torch.aten.slice.Tensor %2502, %int3_2355, %int0_2356, %int64_2357, %int1_2358 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_2359 = torch.constant.int 3 + %int64_2360 = torch.constant.int 64 + %int9223372036854775807_2361 = torch.constant.int 9223372036854775807 %int1_2362 = torch.constant.int 1 - %2136 = torch.aten.add.Tensor %2135, %2134, %int1_2362 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2136, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_2363 = torch.constant.int 4 - %2137 = torch.aten.mul.int %int4_2363, %358 : !torch.int, !torch.int -> !torch.int - %2138 = torch.prim.ListConstruct %2137 : (!torch.int) -> !torch.list - %2139 = torch.aten.view %2136, %2138 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2139, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_2364 = torch.constant.int 32 - %int2_2365 = torch.constant.int 2 - %int32_2366 = torch.constant.int 32 - %int8_2367 = torch.constant.int 8 - %int128_2368 = torch.constant.int 128 - %2140 = torch.prim.ListConstruct %437, %int32_2364, %int2_2365, %int32_2366, %int8_2367, %int128_2368 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2141 = torch.aten.view %2130, %2140 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2141, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2369 = torch.constant.int 32 - %2142 = torch.aten.mul.int %437, %int32_2369 : !torch.int, !torch.int -> !torch.int - %int2_2370 = torch.constant.int 2 - %int32_2371 = torch.constant.int 32 - %int8_2372 = torch.constant.int 8 - %int128_2373 = torch.constant.int 128 - %2143 = torch.prim.ListConstruct %2142, %int2_2370, %int32_2371, %int8_2372, %int128_2373 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2144 = torch.aten.view %2141, %2143 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %2144, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_2374 = torch.constant.int 0 - %2145 = torch.aten.index_select %2144, %int0_2374, %2139 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %2145, [%356], 
affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> + %2505 = torch.aten.slice.Tensor %2502, %int3_2359, %int64_2360, %int9223372036854775807_2361, %int1_2362 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %2506 = torch.aten.neg %2505 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %2507 = torch.prim.ListConstruct %2506, %2504 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_2363 = torch.constant.int -1 + %2508 = torch.aten.cat %2507, %int-1_2363 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %2509 = torch.aten.mul.Tensor %2508, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_2364 = torch.constant.int 1 + %2510 = torch.aten.add.Tensor %2503, %2509, %int1_2364 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_2365 = torch.constant.int 1 + %int2_2366 = torch.constant.int 2 + %2511 = torch.aten.transpose.int %2510, %int1_2365, %int2_2366 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_2367 = torch.constant.int 32 + %2512 = torch.aten.floor_divide.Scalar %arg2, %int32_2367 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_2368 = torch.constant.int 1 + %2513 = torch.aten.unsqueeze %2512, %int1_2368 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_2369 = torch.constant.int 1 + %false_2370 = torch.constant.bool false + %2514 = torch.aten.gather %arg3, %int1_2369, %2513, %false_2370 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_2371 = torch.constant.int 4 + %int1_2372 = torch.constant.int 1 + %int1_2373 = torch.constant.int 1 + %2515 = torch.prim.ListConstruct %int4_2371, %int1_2372, %int1_2373 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2516 = torch.aten.view %2514, %2515 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_2374 = torch.constant.int 32 + %2517 = torch.aten.remainder.Scalar %arg2, %int32_2374 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> %int4_2375 = torch.constant.int 4 - %int2_2376 = torch.constant.int 2 - %int32_2377 = torch.constant.int 32 + %int1_2376 = torch.constant.int 1 + %int1_2377 = torch.constant.int 1 + %2518 = torch.prim.ListConstruct %int4_2375, %int1_2376, %int1_2377 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2519 = torch.aten.view %2517, %2518 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> %int8_2378 = torch.constant.int 8 - %int128_2379 = torch.constant.int 128 - %2146 = torch.prim.ListConstruct %int4_2375, %358, %int2_2376, %int32_2377, %int8_2378, %int128_2379 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2147 = torch.aten.view %2145, %2146 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2147, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_2380 = torch.constant.int 0 - %int0_2381 = torch.constant.int 0 - %int9223372036854775807_2382 = torch.constant.int 9223372036854775807 + %none_2379 = torch.constant.none + %none_2380 = torch.constant.none + %cpu_2381 = torch.constant.device "cpu" + %false_2382 = 
torch.constant.bool false
+ %2520 = torch.aten.arange %int8_2378, %none_2379, %none_2380, %cpu_2381, %false_2382 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64>
 %int1_2383 = torch.constant.int 1
- %2148 = torch.aten.slice.Tensor %2147, %int0_2380, %int0_2381, %int9223372036854775807_2382, %int1_2383 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %2148, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
 %int1_2384 = torch.constant.int 1
- %int0_2385 = torch.constant.int 0
- %int9223372036854775807_2386 = torch.constant.int 9223372036854775807
+ %int8_2385 = torch.constant.int 8
+ %2521 = torch.prim.ListConstruct %int1_2383, %int1_2384, %int8_2385 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %2522 = torch.aten.view %2520, %2521 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[1,1,8],si64>
+ %none_2386 = torch.constant.none
+ %2523 = torch.aten.clone %132, %none_2386 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %2524 = torch.aten.detach %2523 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %2525 = torch.aten.detach %2524 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %2526 = torch.aten.detach %2525 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
 %int1_2387 = torch.constant.int 1
- %2149 = torch.aten.slice.Tensor %2148, %int1_2384, %int0_2385, %int9223372036854775807_2386, %int1_2387 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %2149, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int2_2388 = torch.constant.int 2
- %int0_2389 = torch.constant.int 0
- %2150 = torch.aten.select.int %2149, %int2_2388, %int0_2389 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %2150, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int1_2388 = torch.constant.int 1
+ %int1_2389 = torch.constant.int 1
+ %2527 = torch.prim.ListConstruct %int1_2387, %int1_2388, %int1_2389 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %2528 = torch.aten.view %2526, %2527 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64>
 %int32_2390 = torch.constant.int 32
- %2151 = torch.aten.mul.int %358, %int32_2390 : !torch.int, !torch.int -> !torch.int
- %int2_2391 = torch.constant.int 2
- %int0_2392 = torch.constant.int 0
+ %2529 = torch.aten.mul.Scalar %2516, %int32_2390 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int9 = torch.constant.int 9
+ %int1_2391 = torch.constant.int 1
+ %2530 = torch.aten.add.Scalar %2529, %int9, %int1_2391 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int2_2392 = torch.constant.int 2
+ %2531 = torch.aten.mul.Scalar %2530, %int2_2392 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
 %int1_2393 = torch.constant.int 1
- %2152 = torch.aten.slice.Tensor %2150, %int2_2391, %int0_2392, %2151, %int1_2393 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %2152, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int0_2394 = torch.constant.int 0
- %2153 = torch.aten.clone %2152, %int0_2394 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %2153, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %2532 = torch.aten.add.Tensor %2531, %2528, %int1_2393 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int8_2394 = torch.constant.int 8
+ %2533 = torch.aten.mul.Scalar %2532, %int8_2394 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
 %int1_2395 = torch.constant.int 1
- %2154 = torch.aten.size.int %2149, %int1_2395 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int
+ %2534 = torch.aten.add.Tensor %2533, %2522, %int1_2395 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
 %int32_2396 = torch.constant.int 32
- %2155 = torch.aten.mul.int %2154, %int32_2396 : !torch.int, !torch.int -> !torch.int
- %int4_2397 = torch.constant.int 4
- %int8_2398 = torch.constant.int 8
- %int128_2399 = torch.constant.int 128
- %2156 = torch.prim.ListConstruct %int4_2397, %2155, %int8_2398, %int128_2399 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2157 = torch.aten._unsafe_view %2153, %2156 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %2157, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int0_2400 = torch.constant.int 0
- %int0_2401 = torch.constant.int 0
- %int9223372036854775807_2402 = torch.constant.int 9223372036854775807
- %int1_2403 = torch.constant.int 1
- %2158 = torch.aten.slice.Tensor %2157, %int0_2400, %int0_2401, %int9223372036854775807_2402, %int1_2403 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %2158, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int0_2404 = torch.constant.int 0
- %int0_2405 = torch.constant.int 0
- %int9223372036854775807_2406 = torch.constant.int 9223372036854775807
- %int1_2407 = torch.constant.int 1
- %2159 = torch.aten.slice.Tensor %2147, %int0_2404, %int0_2405, %int9223372036854775807_2406, %int1_2407 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %2159, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int1_2408 = torch.constant.int 1
- %int0_2409 = torch.constant.int 0
- %int9223372036854775807_2410 = torch.constant.int 9223372036854775807
- %int1_2411 = torch.constant.int 1
- %2160 = torch.aten.slice.Tensor %2159, %int1_2408, %int0_2409, %int9223372036854775807_2410, %int1_2411 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %2160, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int2_2412 = torch.constant.int 2
- %int1_2413 = torch.constant.int 1
- %2161 = torch.aten.select.int %2160, %int2_2412, %int1_2413 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %2161, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int2_2414 = torch.constant.int 2
- %int0_2415 = torch.constant.int 0
- %int1_2416 = torch.constant.int 1
- %2162 = torch.aten.slice.Tensor %2161, %int2_2414, %int0_2415, %2151, %int1_2416 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %2162, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int0_2417 = torch.constant.int 0
- %2163 = torch.aten.clone %2162, %int0_2417 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %2163, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int1_2418 = torch.constant.int 1
- %2164 = torch.aten.size.int %2160, %int1_2418 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int
- %int32_2419 = torch.constant.int 32
- %2165 = torch.aten.mul.int %2164, %int32_2419 : !torch.int, !torch.int -> !torch.int
- %int4_2420 = torch.constant.int 4
- %int8_2421 = torch.constant.int 8
- %int128_2422 = torch.constant.int 128
- %2166 = torch.prim.ListConstruct %int4_2420, %2165, %int8_2421, %int128_2422 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2167 = torch.aten._unsafe_view %2163, %2166 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %2167, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int0_2423 = torch.constant.int 0
- %int0_2424 = torch.constant.int 0
- %int9223372036854775807_2425 = torch.constant.int 9223372036854775807
+ %2535 = torch.aten.mul.Scalar %2534, %int32_2396 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int1_2397 = torch.constant.int 1
+ %2536 = torch.aten.add.Tensor %2535, %2519, %int1_2397 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int5_2398 = torch.constant.int 5
+ %2537 = torch.prims.convert_element_type %2511, %int5_2398 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+ %int32_2399 = torch.constant.int 32
+ %int2_2400 = torch.constant.int 2
+ %int8_2401 = torch.constant.int 8
+ %int32_2402 = torch.constant.int 32
+ %int128_2403 = torch.constant.int 128
+ %2538 = torch.prim.ListConstruct %456, %int32_2399, %int2_2400, %int8_2401, %int32_2402, %int128_2403 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %2539 = torch.aten.view %2359, %2538 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %2539, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int128_2404 = torch.constant.int 128
+ %2540 = torch.prim.ListConstruct %596, %int128_2404 : (!torch.int, !torch.int) -> !torch.list<int>
+ %2541 = torch.aten.view %2539, %2540 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %2541, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %2542 = torch.prim.ListConstruct %2536 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+ %false_2405 = torch.constant.bool false
+ %2543 = torch.aten.index_put %2541, %2542, %2537, %false_2405 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %2543, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %int32_2406 = torch.constant.int 32
+ %int2_2407 = torch.constant.int 2
+ %int8_2408 = torch.constant.int 8
+ %int32_2409 = torch.constant.int 32
+ %int128_2410 = torch.constant.int 128
+ %2544 = torch.prim.ListConstruct %456, %int32_2406, %int2_2407, %int8_2408, %int32_2409, %int128_2410 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %2545 = torch.aten.view %2543, %2544 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %2545, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_2411 = torch.constant.int 2097152
+ %2546 = torch.prim.ListConstruct %456, %int2097152_2411 : (!torch.int, !torch.int) -> !torch.list<int>
+ %2547 = torch.aten.view %2545, %2546 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %2547, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %int32_2412 = torch.constant.int 32
+ %int2_2413 = torch.constant.int 2
+ %int8_2414 = torch.constant.int 8
+ %int32_2415 = torch.constant.int 32
+ %int128_2416 = torch.constant.int 128
+ %2548 = torch.prim.ListConstruct %456, %int32_2412, %int2_2413, %int8_2414, %int32_2415, %int128_2416 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %2549 = torch.aten.view %2547, %2548 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %2549, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int128_2417 = torch.constant.int 128
+ %2550 = torch.prim.ListConstruct %596, %int128_2417 : (!torch.int, !torch.int) -> !torch.list<int>
+ %2551 = torch.aten.view %2549, %2550 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %2551, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %none_2418 = torch.constant.none
+ %2552 = torch.aten.clone %133, %none_2418 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %2553 = torch.aten.detach %2552 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %2554 = torch.aten.detach %2553 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %2555 = torch.aten.detach %2554 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %int1_2419 = torch.constant.int 1
+ %int1_2420 = torch.constant.int 1
+ %int1_2421 = torch.constant.int 1
+ %2556 = torch.prim.ListConstruct %int1_2419, %int1_2420, %int1_2421 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %2557 = torch.aten.view %2555, %2556 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64>
+ %int32_2422 = torch.constant.int 32
+ %2558 = torch.aten.mul.Scalar %2516, %int32_2422 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int9_2423 = torch.constant.int 9
+ %int1_2424 = torch.constant.int 1
+ %2559 = torch.aten.add.Scalar %2558, %int9_2423, %int1_2424 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int2_2425 = torch.constant.int 2
+ %2560 = torch.aten.mul.Scalar %2559, %int2_2425 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
 %int1_2426 = torch.constant.int 1
- %2168 = torch.aten.slice.Tensor %2167, %int0_2423, %int0_2424, %int9223372036854775807_2425, %int1_2426 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> 
!torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2168, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_2427 = torch.constant.int -2 - %2169 = torch.aten.unsqueeze %2158, %int-2_2427 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2169, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %2561 = torch.aten.add.Tensor %2560, %2557, %int1_2426 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_2427 = torch.constant.int 8 + %2562 = torch.aten.mul.Scalar %2561, %int8_2427 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_2428 = torch.constant.int 1 - %2170 = torch.aten.size.int %2157, %int1_2428 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_2429 = torch.constant.int 4 - %int8_2430 = torch.constant.int 8 - %int4_2431 = torch.constant.int 4 - %int128_2432 = torch.constant.int 128 - %2171 = torch.prim.ListConstruct %int4_2429, %2170, %int8_2430, %int4_2431, %int128_2432 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2433 = torch.constant.bool false - %2172 = torch.aten.expand %2169, %2171, %false_2433 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2172, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2434 = torch.constant.int 0 - %2173 = torch.aten.clone %2172, %int0_2434 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2173, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2435 = torch.constant.int 4 + %2563 = torch.aten.add.Tensor %2562, %2522, %int1_2428 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_2429 = torch.constant.int 32 + %2564 = torch.aten.mul.Scalar %2563, %int32_2429 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_2430 = torch.constant.int 1 + %2565 = torch.aten.add.Tensor %2564, %2519, %int1_2430 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_2431 = torch.constant.int 5 + %2566 = torch.prims.convert_element_type %2491, %int5_2431 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %2567 = torch.prim.ListConstruct %2565 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_2432 = torch.constant.bool false + %2568 = torch.aten.index_put %2551, %2567, %2566, %false_2432 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2568, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_2433 = torch.constant.int 32 + %int2_2434 = torch.constant.int 2 + %int8_2435 = torch.constant.int 8 %int32_2436 = torch.constant.int 32 %int128_2437 = torch.constant.int 128 - %2174 = torch.prim.ListConstruct %int4_2435, %2170, %int32_2436, %int128_2437 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2175 = torch.aten._unsafe_view %2173, %2174 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2175, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 
128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_2438 = torch.constant.int -2 - %2176 = torch.aten.unsqueeze %2168, %int-2_2438 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2176, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_2439 = torch.constant.int 1 - %2177 = torch.aten.size.int %2167, %int1_2439 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_2440 = torch.constant.int 4 - %int8_2441 = torch.constant.int 8 - %int4_2442 = torch.constant.int 4 - %int128_2443 = torch.constant.int 128 - %2178 = torch.prim.ListConstruct %int4_2440, %2177, %int8_2441, %int4_2442, %int128_2443 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2444 = torch.constant.bool false - %2179 = torch.aten.expand %2176, %2178, %false_2444 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2179, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2445 = torch.constant.int 0 - %2180 = torch.aten.clone %2179, %int0_2445 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2180, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2446 = torch.constant.int 4 - %int32_2447 = torch.constant.int 32 - %int128_2448 = torch.constant.int 128 - %2181 = torch.prim.ListConstruct %int4_2446, %2177, %int32_2447, %int128_2448 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2182 = torch.aten._unsafe_view %2180, %2181 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2182, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_2449 = torch.constant.int 1 - %int2_2450 = torch.constant.int 2 - %2183 = torch.aten.transpose.int %2063, %int1_2449, %int2_2450 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_2451 = torch.constant.int 1 - %int2_2452 = torch.constant.int 2 - %2184 = torch.aten.transpose.int %2175, %int1_2451, %int2_2452 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2184, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_2453 = torch.constant.int 1 - %int2_2454 = torch.constant.int 2 - %2185 = torch.aten.transpose.int %2182, %int1_2453, %int2_2454 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2185, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_2455 = torch.constant.float 0.000000e+00 - %false_2456 = torch.constant.bool false - %none_2457 = torch.constant.none - %2186:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2183, %2184, %2185, %float0.000000e00_2455, %false_2456, %368, %none_2457) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_2458 = torch.constant.int 1 - %int2_2459 = torch.constant.int 2 - %2187 = torch.aten.transpose.int %2186#0, %int1_2458, %int2_2459 : 
!torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
+ %2569 = torch.prim.ListConstruct %456, %int32_2433, %int2_2434, %int8_2435, %int32_2436, %int128_2437 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %2570 = torch.aten.view %2568, %2569 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %2570, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_2438 = torch.constant.int 2097152
+ %2571 = torch.prim.ListConstruct %456, %int2097152_2438 : (!torch.int, !torch.int) -> !torch.list<int>
+ %2572 = torch.aten.view %2570, %2571 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %2572, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %none_2439 = torch.constant.none
+ %2573 = torch.aten.clone %134, %none_2439 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %2574 = torch.aten.detach %2573 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %2575 = torch.aten.detach %2574 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %2576 = torch.aten.detach %2575 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %none_2440 = torch.constant.none
+ %2577 = torch.aten.clone %135, %none_2440 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %2578 = torch.aten.detach %2577 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %2579 = torch.aten.detach %2578 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %2580 = torch.aten.detach %2579 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %none_2441 = torch.constant.none
+ %2581 = torch.aten.clone %136, %none_2441 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %2582 = torch.aten.detach %2581 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %2583 = torch.aten.detach %2582 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %2584 = torch.aten.detach %2583 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %int32_2442 = torch.constant.int 32
+ %int2_2443 = torch.constant.int 2
+ %int8_2444 = torch.constant.int 8
+ %int32_2445 = torch.constant.int 32
+ %int128_2446 = torch.constant.int 128
+ %2585 = torch.prim.ListConstruct %456, %int32_2442, %int2_2443, %int8_2444, %int32_2445, %int128_2446 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %2586 = torch.aten.view %2572, %2585 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %2586, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %2587 = torch_c.to_builtin_tensor %2586 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+ %2588 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+ %cast_2447 = tensor.cast %2588 : tensor<4x?xi64> to tensor<?x?xi64>
+ %2589 = torch_c.to_builtin_tensor %2576 : !torch.vtensor<[],si64> -> tensor<i64>
+ %2590 = torch_c.to_builtin_tensor %2580 : !torch.vtensor<[],si64> -> tensor<i64>
+ %2591 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%2587, %cast_2447, %2589, %2590) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+ %cast_2448 = tensor.cast %2591 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+ %2592 = torch_c.from_builtin_tensor %cast_2448 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+ torch.bind_symbolic_shape %2592, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+ %2593 = torch_c.to_builtin_tensor %2586 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+ %2594 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+ %cast_2449 = tensor.cast %2594 : tensor<4x?xi64> to tensor<?x?xi64>
+ %2595 = torch_c.to_builtin_tensor %2576 : !torch.vtensor<[],si64> -> tensor<i64>
+ %2596 = torch_c.to_builtin_tensor %2584 : !torch.vtensor<[],si64> -> tensor<i64>
+ %2597 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%2593, %cast_2449, %2595, %2596) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+ %cast_2450 = tensor.cast %2597 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+ %2598 = torch_c.from_builtin_tensor %cast_2450 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+ torch.bind_symbolic_shape %2598, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+ %int2_2451 = torch.constant.int 2
+ %int3_2452 = torch.constant.int 3
+ %2599 = torch.aten.transpose.int %2592, %int2_2451, %int3_2452 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %2599, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int0_2453 = torch.constant.int 0
+ %2600 = torch.aten.clone %2599, %int0_2453 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %2600, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int4_2454 = torch.constant.int 4
+ %int8_2455 = torch.constant.int 8
+ %int128_2456 = torch.constant.int 128
+ %2601 = torch.prim.ListConstruct %int4_2454, %457, %int8_2455, %int128_2456 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %2602 = torch.aten._unsafe_view %2600, %2601 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %2602, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int2_2457 = torch.constant.int 2
+ %int3_2458 = torch.constant.int 3
+ %2603 = torch.aten.transpose.int %2598, %int2_2457, %int3_2458 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %2603, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int0_2459 = torch.constant.int 0
+ %2604 = torch.aten.clone %2603, %int0_2459 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %2604, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
 %int4_2460 = torch.constant.int 4
- %int1_2461 = torch.constant.int 1
- %int4096_2462 = torch.constant.int 4096
- %2188 = torch.prim.ListConstruct %int4_2460, %int1_2461, %int4096_2462 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2189 = torch.aten.view %2187, %2188 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+ %int8_2461 = torch.constant.int 8
+ %int128_2462 = torch.constant.int 128
+ %2605 = torch.prim.ListConstruct 
%int4_2460, %457, %int8_2461, %int128_2462 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2606 = torch.aten._unsafe_view %2604, %2605 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2606, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> %int-2_2463 = torch.constant.int -2 - %int-1_2464 = torch.constant.int -1 - %2190 = torch.aten.transpose.int %95, %int-2_2463, %int-1_2464 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2465 = torch.constant.int 4 - %int4096_2466 = torch.constant.int 4096 - %2191 = torch.prim.ListConstruct %int4_2465, %int4096_2466 : (!torch.int, !torch.int) -> !torch.list - %2192 = torch.aten.view %2189, %2191 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2193 = torch.aten.mm %2192, %2190 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_2467 = torch.constant.int 4 - %int1_2468 = torch.constant.int 1 - %int4096_2469 = torch.constant.int 4096 - %2194 = torch.prim.ListConstruct %int4_2467, %int1_2468, %int4096_2469 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2195 = torch.aten.view %2193, %2194 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_2470 = torch.constant.int 1 - %2196 = torch.aten.add.Tensor %2023, %2195, %int1_2470 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_2471 = torch.constant.int 6 - %2197 = torch.prims.convert_element_type %2196, %int6_2471 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_2472 = torch.constant.int 2 - %2198 = torch.aten.pow.Tensor_Scalar %2197, %int2_2472 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_2473 = torch.constant.int -1 - %2199 = torch.prim.ListConstruct %int-1_2473 : (!torch.int) -> !torch.list - %true_2474 = torch.constant.bool true - %none_2475 = torch.constant.none - %2200 = torch.aten.mean.dim %2198, %2199, %true_2474, %none_2475 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_2476 = torch.constant.float 9.9999997473787516E-6 - %int1_2477 = torch.constant.int 1 - %2201 = torch.aten.add.Scalar %2200, %float9.999990e-06_2476, %int1_2477 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %2202 = torch.aten.rsqrt %2201 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %2203 = torch.aten.mul.Tensor %2197, %2202 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_2478 = torch.constant.int 5 - %2204 = torch.prims.convert_element_type %2203, %int5_2478 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %2205 = torch.aten.mul.Tensor %96, %2204 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_2479 = torch.constant.int 5 - %2206 = torch.prims.convert_element_type %2205, %int5_2479 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_2480 = torch.constant.int -2 - %int-1_2481 = torch.constant.int -1 - %2207 = torch.aten.transpose.int %97, %int-2_2480, %int-1_2481 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_2482 = 
torch.constant.int 4 - %int4096_2483 = torch.constant.int 4096 - %2208 = torch.prim.ListConstruct %int4_2482, %int4096_2483 : (!torch.int, !torch.int) -> !torch.list - %2209 = torch.aten.view %2206, %2208 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2210 = torch.aten.mm %2209, %2207 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_2484 = torch.constant.int 4 + %2607 = torch.aten.unsqueeze %2602, %int-2_2463 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2607, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_2464 = torch.constant.int 4 + %int8_2465 = torch.constant.int 8 + %int4_2466 = torch.constant.int 4 + %int128_2467 = torch.constant.int 128 + %2608 = torch.prim.ListConstruct %int4_2464, %457, %int8_2465, %int4_2466, %int128_2467 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2468 = torch.constant.bool false + %2609 = torch.aten.expand %2607, %2608, %false_2468 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2609, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2469 = torch.constant.int 0 + %2610 = torch.aten.clone %2609, %int0_2469 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2610, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_2470 = torch.constant.int 4 + %int32_2471 = torch.constant.int 32 + %int128_2472 = torch.constant.int 128 + %2611 = torch.prim.ListConstruct %int4_2470, %457, %int32_2471, %int128_2472 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2612 = torch.aten._unsafe_view %2610, %2611 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2612, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_2473 = torch.constant.int -2 + %2613 = torch.aten.unsqueeze %2606, %int-2_2473 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2613, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_2474 = torch.constant.int 4 + %int8_2475 = torch.constant.int 8 + %int4_2476 = torch.constant.int 4 + %int128_2477 = torch.constant.int 128 + %2614 = torch.prim.ListConstruct %int4_2474, %457, %int8_2475, %int4_2476, %int128_2477 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2478 = torch.constant.bool false + %2615 = torch.aten.expand %2613, %2614, %false_2478 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2615, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2479 = torch.constant.int 0 + %2616 = torch.aten.clone %2615, %int0_2479 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2616, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_2480 = torch.constant.int 4 + %int32_2481 = torch.constant.int 32 + %int128_2482 = torch.constant.int 128 + %2617 = torch.prim.ListConstruct %int4_2480, %457, %int32_2481, 
%int128_2482 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2618 = torch.aten._unsafe_view %2616, %2617 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2618, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_2483 = torch.constant.int 1 + %int2_2484 = torch.constant.int 2 + %2619 = torch.aten.transpose.int %2501, %int1_2483, %int2_2484 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_2485 = torch.constant.int 1 - %int14336_2486 = torch.constant.int 14336 - %2211 = torch.prim.ListConstruct %int4_2484, %int1_2485, %int14336_2486 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2212 = torch.aten.view %2210, %2211 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %2213 = torch.aten.silu %2212 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_2487 = torch.constant.int -2 - %int-1_2488 = torch.constant.int -1 - %2214 = torch.aten.transpose.int %98, %int-2_2487, %int-1_2488 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_2489 = torch.constant.int 4 - %int4096_2490 = torch.constant.int 4096 - %2215 = torch.prim.ListConstruct %int4_2489, %int4096_2490 : (!torch.int, !torch.int) -> !torch.list - %2216 = torch.aten.view %2206, %2215 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2217 = torch.aten.mm %2216, %2214 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_2491 = torch.constant.int 4 + %int2_2486 = torch.constant.int 2 + %2620 = torch.aten.transpose.int %2612, %int1_2485, %int2_2486 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2620, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2487 = torch.constant.int 1 + %int2_2488 = torch.constant.int 2 + %2621 = torch.aten.transpose.int %2618, %int1_2487, %int2_2488 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2621, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_2489 = torch.constant.float 0.000000e+00 + %false_2490 = torch.constant.bool false + %none_2491 = torch.constant.none + %2622:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2619, %2620, %2621, %float0.000000e00_2489, %false_2490, %470, %none_2491) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) %int1_2492 = torch.constant.int 1 - %int14336_2493 = torch.constant.int 14336 - %2218 = torch.prim.ListConstruct %int4_2491, %int1_2492, %int14336_2493 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2219 = torch.aten.view %2217, %2218 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %2220 = torch.aten.mul.Tensor %2213, %2219 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_2494 = torch.constant.int -2 - %int-1_2495 = torch.constant.int -1 - %2221 = torch.aten.transpose.int %99, %int-2_2494, %int-1_2495 : 
!torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_2496 = torch.constant.int 4 - %int14336_2497 = torch.constant.int 14336 - %2222 = torch.prim.ListConstruct %int4_2496, %int14336_2497 : (!torch.int, !torch.int) -> !torch.list - %2223 = torch.aten.view %2220, %2222 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %2224 = torch.aten.mm %2223, %2221 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_2498 = torch.constant.int 4 - %int1_2499 = torch.constant.int 1 - %int4096_2500 = torch.constant.int 4096 - %2225 = torch.prim.ListConstruct %int4_2498, %int1_2499, %int4096_2500 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2226 = torch.aten.view %2224, %2225 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_2501 = torch.constant.int 1 - %2227 = torch.aten.add.Tensor %2196, %2226, %int1_2501 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_2502 = torch.constant.int 6 - %2228 = torch.prims.convert_element_type %2227, %int6_2502 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_2503 = torch.constant.int 2 - %2229 = torch.aten.pow.Tensor_Scalar %2228, %int2_2503 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_2504 = torch.constant.int -1 - %2230 = torch.prim.ListConstruct %int-1_2504 : (!torch.int) -> !torch.list - %true_2505 = torch.constant.bool true - %none_2506 = torch.constant.none - %2231 = torch.aten.mean.dim %2229, %2230, %true_2505, %none_2506 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_2507 = torch.constant.float 9.9999997473787516E-6 - %int1_2508 = torch.constant.int 1 - %2232 = torch.aten.add.Scalar %2231, %float9.999990e-06_2507, %int1_2508 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %2233 = torch.aten.rsqrt %2232 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %2234 = torch.aten.mul.Tensor %2228, %2233 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_2509 = torch.constant.int 5 - %2235 = torch.prims.convert_element_type %2234, %int5_2509 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %2236 = torch.aten.mul.Tensor %100, %2235 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_2510 = torch.constant.int 5 - %2237 = torch.prims.convert_element_type %2236, %int5_2510 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_2511 = torch.constant.int -2 - %int-1_2512 = torch.constant.int -1 - %2238 = torch.aten.transpose.int %101, %int-2_2511, %int-1_2512 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2513 = torch.constant.int 4 - %int4096_2514 = torch.constant.int 4096 - %2239 = torch.prim.ListConstruct %int4_2513, %int4096_2514 : (!torch.int, !torch.int) -> !torch.list - %2240 = torch.aten.view %2237, %2239 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2241 = torch.aten.mm %2240, %2238 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_2515 = torch.constant.int 4 - %int1_2516 = torch.constant.int 1 - %int4096_2517 = 
torch.constant.int 4096 - %2242 = torch.prim.ListConstruct %int4_2515, %int1_2516, %int4096_2517 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2243 = torch.aten.view %2241, %2242 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_2518 = torch.constant.int -2 - %int-1_2519 = torch.constant.int -1 - %2244 = torch.aten.transpose.int %102, %int-2_2518, %int-1_2519 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int2_2493 = torch.constant.int 2 + %2623 = torch.aten.transpose.int %2622#0, %int1_2492, %int2_2493 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_2494 = torch.constant.int 4 + %int1_2495 = torch.constant.int 1 + %int4096_2496 = torch.constant.int 4096 + %2624 = torch.prim.ListConstruct %int4_2494, %int1_2495, %int4096_2496 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2625 = torch.aten.view %2623, %2624 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_2497 = torch.constant.int -2 + %int-1_2498 = torch.constant.int -1 + %2626 = torch.aten.transpose.int %137, %int-2_2497, %int-1_2498 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2499 = torch.constant.int 5 + %2627 = torch.prims.convert_element_type %2626, %int5_2499 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_2500 = torch.constant.int 4 + %int4096_2501 = torch.constant.int 4096 + %2628 = torch.prim.ListConstruct %int4_2500, %int4096_2501 : (!torch.int, !torch.int) -> !torch.list + %2629 = torch.aten.view %2625, %2628 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2630 = torch.aten.mm %2629, %2627 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_2502 = torch.constant.int 4 + %int1_2503 = torch.constant.int 1 + %int4096_2504 = torch.constant.int 4096 + %2631 = torch.prim.ListConstruct %int4_2502, %int1_2503, %int4096_2504 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2632 = torch.aten.view %2630, %2631 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_2505 = torch.constant.int 1 + %2633 = torch.aten.add.Tensor %2454, %2632, %int1_2505 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_2506 = torch.constant.int 6 + %2634 = torch.prims.convert_element_type %2633, %int6_2506 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_2507 = torch.constant.int 2 + %2635 = torch.aten.pow.Tensor_Scalar %2634, %int2_2507 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_2508 = torch.constant.int -1 + %2636 = torch.prim.ListConstruct %int-1_2508 : (!torch.int) -> !torch.list + %true_2509 = torch.constant.bool true + %none_2510 = torch.constant.none + %2637 = torch.aten.mean.dim %2635, %2636, %true_2509, %none_2510 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_2511 = torch.constant.float 9.9999997473787516E-6 + %int1_2512 = torch.constant.int 1 + %2638 = torch.aten.add.Scalar %2637, %float9.999990e-06_2511, %int1_2512 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %2639 = torch.aten.rsqrt %2638 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %2640 = 
torch.aten.mul.Tensor %2634, %2639 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_2513 = torch.constant.int 5 + %2641 = torch.prims.convert_element_type %2640, %int5_2513 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %2642 = torch.aten.mul.Tensor %138, %2641 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_2514 = torch.constant.int 5 + %2643 = torch.prims.convert_element_type %2642, %int5_2514 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_2515 = torch.constant.int -2 + %int-1_2516 = torch.constant.int -1 + %2644 = torch.aten.transpose.int %139, %int-2_2515, %int-1_2516 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2517 = torch.constant.int 5 + %2645 = torch.prims.convert_element_type %2644, %int5_2517 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_2518 = torch.constant.int 4 + %int4096_2519 = torch.constant.int 4096 + %2646 = torch.prim.ListConstruct %int4_2518, %int4096_2519 : (!torch.int, !torch.int) -> !torch.list + %2647 = torch.aten.view %2643, %2646 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2648 = torch.aten.mm %2647, %2645 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> %int4_2520 = torch.constant.int 4 - %int4096_2521 = torch.constant.int 4096 - %2245 = torch.prim.ListConstruct %int4_2520, %int4096_2521 : (!torch.int, !torch.int) -> !torch.list - %2246 = torch.aten.view %2237, %2245 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2247 = torch.aten.mm %2246, %2244 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_2522 = torch.constant.int 4 - %int1_2523 = torch.constant.int 1 - %int1024_2524 = torch.constant.int 1024 - %2248 = torch.prim.ListConstruct %int4_2522, %int1_2523, %int1024_2524 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2249 = torch.aten.view %2247, %2248 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_2525 = torch.constant.int -2 - %int-1_2526 = torch.constant.int -1 - %2250 = torch.aten.transpose.int %103, %int-2_2525, %int-1_2526 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2527 = torch.constant.int 4 - %int4096_2528 = torch.constant.int 4096 - %2251 = torch.prim.ListConstruct %int4_2527, %int4096_2528 : (!torch.int, !torch.int) -> !torch.list - %2252 = torch.aten.view %2237, %2251 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2253 = torch.aten.mm %2252, %2250 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_2529 = torch.constant.int 4 - %int1_2530 = torch.constant.int 1 - %int1024_2531 = torch.constant.int 1024 - %2254 = torch.prim.ListConstruct %int4_2529, %int1_2530, %int1024_2531 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2255 = torch.aten.view %2253, %2254 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_2532 = torch.constant.int 4 - %int1_2533 = torch.constant.int 1 - %int32_2534 = torch.constant.int 32 - %int128_2535 = torch.constant.int 128 - %2256 = torch.prim.ListConstruct %int4_2532, %int1_2533, %int32_2534, %int128_2535 : (!torch.int, !torch.int, 
!torch.int, !torch.int) -> !torch.list - %2257 = torch.aten.view %2243, %2256 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int1_2521 = torch.constant.int 1 + %int14336_2522 = torch.constant.int 14336 + %2649 = torch.prim.ListConstruct %int4_2520, %int1_2521, %int14336_2522 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2650 = torch.aten.view %2648, %2649 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %2651 = torch.aten.silu %2650 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_2523 = torch.constant.int -2 + %int-1_2524 = torch.constant.int -1 + %2652 = torch.aten.transpose.int %140, %int-2_2523, %int-1_2524 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2525 = torch.constant.int 5 + %2653 = torch.prims.convert_element_type %2652, %int5_2525 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_2526 = torch.constant.int 4 + %int4096_2527 = torch.constant.int 4096 + %2654 = torch.prim.ListConstruct %int4_2526, %int4096_2527 : (!torch.int, !torch.int) -> !torch.list + %2655 = torch.aten.view %2643, %2654 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2656 = torch.aten.mm %2655, %2653 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_2528 = torch.constant.int 4 + %int1_2529 = torch.constant.int 1 + %int14336_2530 = torch.constant.int 14336 + %2657 = torch.prim.ListConstruct %int4_2528, %int1_2529, %int14336_2530 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2658 = torch.aten.view %2656, %2657 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %2659 = torch.aten.mul.Tensor %2651, %2658 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_2531 = torch.constant.int -2 + %int-1_2532 = torch.constant.int -1 + %2660 = torch.aten.transpose.int %141, %int-2_2531, %int-1_2532 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_2533 = torch.constant.int 5 + %2661 = torch.prims.convert_element_type %2660, %int5_2533 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_2534 = torch.constant.int 4 + %int14336_2535 = torch.constant.int 14336 + %2662 = torch.prim.ListConstruct %int4_2534, %int14336_2535 : (!torch.int, !torch.int) -> !torch.list + %2663 = torch.aten.view %2659, %2662 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %2664 = torch.aten.mm %2663, %2661 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_2536 = torch.constant.int 4 %int1_2537 = torch.constant.int 1 - %int8_2538 = torch.constant.int 8 - %int128_2539 = torch.constant.int 128 - %2258 = torch.prim.ListConstruct %int4_2536, %int1_2537, %int8_2538, %int128_2539 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2259 = torch.aten.view %2249, %2258 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_2540 = torch.constant.int 4 - %int1_2541 = torch.constant.int 1 - %int8_2542 = torch.constant.int 8 - %int128_2543 = torch.constant.int 128 - %2260 = torch.prim.ListConstruct %int4_2540, %int1_2541, %int8_2542, %int128_2543 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2261 = torch.aten.view %2255, %2260 
: !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_2544 = torch.constant.int 6 - %2262 = torch.prims.convert_element_type %2257, %int6_2544 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %2263 = torch_c.to_builtin_tensor %2262 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %2264 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %2265 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%2263, %2264) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %2266 = torch_c.from_builtin_tensor %2265 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_2545 = torch.constant.int 5 - %2267 = torch.prims.convert_element_type %2266, %int5_2545 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_2546 = torch.constant.int 6 - %2268 = torch.prims.convert_element_type %2259, %int6_2546 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %2269 = torch_c.to_builtin_tensor %2268 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %2270 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %2271 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%2269, %2270) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %2272 = torch_c.from_builtin_tensor %2271 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> + %int4096_2538 = torch.constant.int 4096 + %2665 = torch.prim.ListConstruct %int4_2536, %int1_2537, %int4096_2538 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2666 = torch.aten.view %2664, %2665 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_2539 = torch.constant.int 1 + %2667 = torch.aten.add.Tensor %2633, %2666, %int1_2539 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_2540 = torch.constant.int 6 + %2668 = torch.prims.convert_element_type %2667, %int6_2540 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_2541 = torch.constant.int 2 + %2669 = torch.aten.pow.Tensor_Scalar %2668, %int2_2541 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_2542 = torch.constant.int -1 + %2670 = torch.prim.ListConstruct %int-1_2542 : (!torch.int) -> !torch.list + %true_2543 = torch.constant.bool true + %none_2544 = torch.constant.none + %2671 = torch.aten.mean.dim %2669, %2670, %true_2543, %none_2544 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_2545 = torch.constant.float 9.9999997473787516E-6 + %int1_2546 = torch.constant.int 1 + %2672 = torch.aten.add.Scalar %2671, %float9.999990e-06_2545, %int1_2546 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %2673 = torch.aten.rsqrt %2672 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %2674 = torch.aten.mul.Tensor %2668, %2673 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> %int5_2547 = torch.constant.int 5 - %2273 = torch.prims.convert_element_type %2272, %int5_2547 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_2548 = torch.constant.int 32 - %2274 = torch.aten.floor_divide.Scalar %arg2, %int32_2548 : !torch.vtensor<[4],si64>, !torch.int -> 
!torch.vtensor<[4],si64> - %int1_2549 = torch.constant.int 1 - %2275 = torch.aten.unsqueeze %2274, %int1_2549 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2550 = torch.constant.int 1 - %false_2551 = torch.constant.bool false - %2276 = torch.aten.gather %arg3, %int1_2550, %2275, %false_2551 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_2552 = torch.constant.int 32 - %2277 = torch.aten.remainder.Scalar %arg2, %int32_2552 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2553 = torch.constant.int 1 - %2278 = torch.aten.unsqueeze %2277, %int1_2553 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_2554 = torch.constant.none - %2279 = torch.aten.clone %104, %none_2554 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_2555 = torch.constant.int 0 - %2280 = torch.aten.unsqueeze %2279, %int0_2555 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_2556 = torch.constant.int 4 - %int1_2557 = torch.constant.int 1 - %2281 = torch.prim.ListConstruct %int4_2556, %int1_2557 : (!torch.int, !torch.int) -> !torch.list - %int1_2558 = torch.constant.int 1 - %int1_2559 = torch.constant.int 1 - %2282 = torch.prim.ListConstruct %int1_2558, %int1_2559 : (!torch.int, !torch.int) -> !torch.list + %2675 = torch.prims.convert_element_type %2674, %int5_2547 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %2676 = torch.aten.mul.Tensor %142, %2675 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_2548 = torch.constant.int 5 + %2677 = torch.prims.convert_element_type %2676, %int5_2548 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_2549 = torch.constant.int -2 + %int-1_2550 = torch.constant.int -1 + %2678 = torch.aten.transpose.int %143, %int-2_2549, %int-1_2550 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2551 = torch.constant.int 5 + %2679 = torch.prims.convert_element_type %2678, %int5_2551 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_2552 = torch.constant.int 4 + %int4096_2553 = torch.constant.int 4096 + %2680 = torch.prim.ListConstruct %int4_2552, %int4096_2553 : (!torch.int, !torch.int) -> !torch.list + %2681 = torch.aten.view %2677, %2680 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2682 = torch.aten.mm %2681, %2679 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_2554 = torch.constant.int 4 + %int1_2555 = torch.constant.int 1 + %int4096_2556 = torch.constant.int 4096 + %2683 = torch.prim.ListConstruct %int4_2554, %int1_2555, %int4096_2556 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2684 = torch.aten.view %2682, %2683 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_2557 = torch.constant.int -2 + %int-1_2558 = torch.constant.int -1 + %2685 = torch.aten.transpose.int %144, %int-2_2557, %int-1_2558 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2559 = torch.constant.int 5 + %2686 = torch.prims.convert_element_type %2685, %int5_2559 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> %int4_2560 = torch.constant.int 4 - %int0_2561 = torch.constant.int 0 - %cpu_2562 = 
torch.constant.device "cpu" - %false_2563 = torch.constant.bool false - %2283 = torch.aten.empty_strided %2281, %2282, %int4_2560, %int0_2561, %cpu_2562, %false_2563 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int9 = torch.constant.int 9 - %2284 = torch.aten.fill.Scalar %2283, %int9 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_2564 = torch.constant.int 4 - %int1_2565 = torch.constant.int 1 - %2285 = torch.prim.ListConstruct %int4_2564, %int1_2565 : (!torch.int, !torch.int) -> !torch.list - %2286 = torch.aten.repeat %2280, %2285 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_2566 = torch.constant.int 32 - %2287 = torch.aten.mul.Scalar %2276, %int32_2566 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2567 = torch.constant.int 1 - %2288 = torch.aten.add.Tensor %2287, %2284, %int1_2567 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_2568 = torch.constant.int 2 - %2289 = torch.aten.mul.Scalar %2288, %int2_2568 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2569 = torch.constant.int 1 - %2290 = torch.aten.add.Tensor %2289, %2286, %int1_2569 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_2570 = torch.constant.int 32 - %2291 = torch.aten.mul.Scalar %2290, %int32_2570 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int4096_2561 = torch.constant.int 4096 + %2687 = torch.prim.ListConstruct %int4_2560, %int4096_2561 : (!torch.int, !torch.int) -> !torch.list + %2688 = torch.aten.view %2677, %2687 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2689 = torch.aten.mm %2688, %2686 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_2562 = torch.constant.int 4 + %int1_2563 = torch.constant.int 1 + %int1024_2564 = torch.constant.int 1024 + %2690 = torch.prim.ListConstruct %int4_2562, %int1_2563, %int1024_2564 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2691 = torch.aten.view %2689, %2690 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_2565 = torch.constant.int -2 + %int-1_2566 = torch.constant.int -1 + %2692 = torch.aten.transpose.int %145, %int-2_2565, %int-1_2566 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2567 = torch.constant.int 5 + %2693 = torch.prims.convert_element_type %2692, %int5_2567 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_2568 = torch.constant.int 4 + %int4096_2569 = torch.constant.int 4096 + %2694 = torch.prim.ListConstruct %int4_2568, %int4096_2569 : (!torch.int, !torch.int) -> !torch.list + %2695 = torch.aten.view %2677, %2694 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2696 = torch.aten.mm %2695, %2693 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_2570 = torch.constant.int 4 %int1_2571 = torch.constant.int 1 - %2292 = torch.aten.add.Tensor %2291, %2278, %int1_2571 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_2572 = torch.constant.int 32 - %int2_2573 = torch.constant.int 2 - %int32_2574 = torch.constant.int 32 - %int8_2575 = torch.constant.int 8 + %int1024_2572 
= torch.constant.int 1024 + %2697 = torch.prim.ListConstruct %int4_2570, %int1_2571, %int1024_2572 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2698 = torch.aten.view %2696, %2697 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_2573 = torch.constant.int 4 + %int1_2574 = torch.constant.int 1 + %int32_2575 = torch.constant.int 32 %int128_2576 = torch.constant.int 128 - %2293 = torch.prim.ListConstruct %437, %int32_2572, %int2_2573, %int32_2574, %int8_2575, %int128_2576 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2294 = torch.aten.view %2130, %2293 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2294, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2577 = torch.constant.int 32 - %2295 = torch.aten.mul.int %437, %int32_2577 : !torch.int, !torch.int -> !torch.int - %int2_2578 = torch.constant.int 2 - %2296 = torch.aten.mul.int %2295, %int2_2578 : !torch.int, !torch.int -> !torch.int - %int32_2579 = torch.constant.int 32 - %2297 = torch.aten.mul.int %2296, %int32_2579 : !torch.int, !torch.int -> !torch.int - %int8_2580 = torch.constant.int 8 - %int128_2581 = torch.constant.int 128 - %2298 = torch.prim.ListConstruct %2297, %int8_2580, %int128_2581 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2299 = torch.aten.view %2294, %2298 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2299, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %2300 = torch.prim.ListConstruct %2292 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_2582 = torch.constant.bool false - %2301 = torch.aten.index_put %2299, %2300, %2273, %false_2582 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2301, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_2583 = torch.constant.int 32 - %int2_2584 = torch.constant.int 2 - %int32_2585 = torch.constant.int 32 - %int8_2586 = torch.constant.int 8 - %int128_2587 = torch.constant.int 128 - %2302 = torch.prim.ListConstruct %437, %int32_2583, %int2_2584, %int32_2585, %int8_2586, %int128_2587 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2303 = torch.aten.view %2301, %2302 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2303, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2588 = torch.constant.int 2097152 - %2304 = torch.prim.ListConstruct %437, %int2097152_2588 : (!torch.int, !torch.int) -> !torch.list - %2305 = torch.aten.view %2303, %2304 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2305, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_2589 = torch.constant.int 32 - %int2_2590 = torch.constant.int 2 - %int32_2591 = torch.constant.int 32 - %int8_2592 = torch.constant.int 8 - %int128_2593 = torch.constant.int 128 - %2306 = torch.prim.ListConstruct %437, %int32_2589, %int2_2590, %int32_2591, %int8_2592, %int128_2593 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2307 = torch.aten.view %2305, 
%2306 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2307, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_2594 = torch.constant.int 8 - %int128_2595 = torch.constant.int 128 - %2308 = torch.prim.ListConstruct %2297, %int8_2594, %int128_2595 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2309 = torch.aten.view %2307, %2308 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2309, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_2596 = torch.constant.int 32 - %2310 = torch.aten.floor_divide.Scalar %arg2, %int32_2596 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %2699 = torch.prim.ListConstruct %int4_2573, %int1_2574, %int32_2575, %int128_2576 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2700 = torch.aten.view %2684, %2699 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_2577 = torch.constant.int 4 + %int1_2578 = torch.constant.int 1 + %int8_2579 = torch.constant.int 8 + %int128_2580 = torch.constant.int 128 + %2701 = torch.prim.ListConstruct %int4_2577, %int1_2578, %int8_2579, %int128_2580 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2702 = torch.aten.view %2691, %2701 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_2581 = torch.constant.int 4 + %int1_2582 = torch.constant.int 1 + %int8_2583 = torch.constant.int 8 + %int128_2584 = torch.constant.int 128 + %2703 = torch.prim.ListConstruct %int4_2581, %int1_2582, %int8_2583, %int128_2584 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2704 = torch.aten.view %2698, %2703 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_2585 = torch.constant.int 1 + %int2_2586 = torch.constant.int 2 + %2705 = torch.aten.transpose.int %2700, %int1_2585, %int2_2586 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %2706 = torch.aten.mul.Tensor %2705, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_2587 = torch.constant.int 3 + %int0_2588 = torch.constant.int 0 + %int64_2589 = torch.constant.int 64 + %int1_2590 = torch.constant.int 1 + %2707 = torch.aten.slice.Tensor %2705, %int3_2587, %int0_2588, %int64_2589, %int1_2590 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_2591 = torch.constant.int 3 + %int64_2592 = torch.constant.int 64 + %int9223372036854775807_2593 = torch.constant.int 9223372036854775807 + %int1_2594 = torch.constant.int 1 + %2708 = torch.aten.slice.Tensor %2705, %int3_2591, %int64_2592, %int9223372036854775807_2593, %int1_2594 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %2709 = torch.aten.neg %2708 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %2710 = torch.prim.ListConstruct %2709, %2707 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_2595 = torch.constant.int -1 + %2711 = torch.aten.cat %2710, %int-1_2595 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %2712 = torch.aten.mul.Tensor %2711, %531 : !torch.vtensor<[4,32,1,128],f16>, 
!torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_2596 = torch.constant.int 1 + %2713 = torch.aten.add.Tensor %2706, %2712, %int1_2596 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_2597 = torch.constant.int 1 - %2311 = torch.aten.unsqueeze %2310, %int1_2597 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2598 = torch.constant.int 1 - %false_2599 = torch.constant.bool false - %2312 = torch.aten.gather %arg3, %int1_2598, %2311, %false_2599 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_2600 = torch.constant.int 32 - %2313 = torch.aten.remainder.Scalar %arg2, %int32_2600 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2601 = torch.constant.int 1 - %2314 = torch.aten.unsqueeze %2313, %int1_2601 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_2602 = torch.constant.none - %2315 = torch.aten.clone %105, %none_2602 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_2603 = torch.constant.int 0 - %2316 = torch.aten.unsqueeze %2315, %int0_2603 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_2604 = torch.constant.int 4 - %int1_2605 = torch.constant.int 1 - %2317 = torch.prim.ListConstruct %int4_2604, %int1_2605 : (!torch.int, !torch.int) -> !torch.list - %int1_2606 = torch.constant.int 1 - %int1_2607 = torch.constant.int 1 - %2318 = torch.prim.ListConstruct %int1_2606, %int1_2607 : (!torch.int, !torch.int) -> !torch.list - %int4_2608 = torch.constant.int 4 - %int0_2609 = torch.constant.int 0 - %cpu_2610 = torch.constant.device "cpu" - %false_2611 = torch.constant.bool false - %2319 = torch.aten.empty_strided %2317, %2318, %int4_2608, %int0_2609, %cpu_2610, %false_2611 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int9_2612 = torch.constant.int 9 - %2320 = torch.aten.fill.Scalar %2319, %int9_2612 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_2613 = torch.constant.int 4 + %int2_2598 = torch.constant.int 2 + %2714 = torch.aten.transpose.int %2713, %int1_2597, %int2_2598 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_2599 = torch.constant.int 1 + %int2_2600 = torch.constant.int 2 + %2715 = torch.aten.transpose.int %2702, %int1_2599, %int2_2600 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %2716 = torch.aten.mul.Tensor %2715, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_2601 = torch.constant.int 3 + %int0_2602 = torch.constant.int 0 + %int64_2603 = torch.constant.int 64 + %int1_2604 = torch.constant.int 1 + %2717 = torch.aten.slice.Tensor %2715, %int3_2601, %int0_2602, %int64_2603, %int1_2604 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_2605 = torch.constant.int 3 + %int64_2606 = torch.constant.int 64 + %int9223372036854775807_2607 = torch.constant.int 9223372036854775807 + %int1_2608 = torch.constant.int 1 + %2718 = torch.aten.slice.Tensor %2715, %int3_2605, %int64_2606, %int9223372036854775807_2607, %int1_2608 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %2719 = torch.aten.neg 
%2718 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %2720 = torch.prim.ListConstruct %2719, %2717 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_2609 = torch.constant.int -1 + %2721 = torch.aten.cat %2720, %int-1_2609 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %2722 = torch.aten.mul.Tensor %2721, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_2610 = torch.constant.int 1 + %2723 = torch.aten.add.Tensor %2716, %2722, %int1_2610 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_2611 = torch.constant.int 1 + %int2_2612 = torch.constant.int 2 + %2724 = torch.aten.transpose.int %2723, %int1_2611, %int2_2612 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_2613 = torch.constant.int 32 + %2725 = torch.aten.floor_divide.Scalar %arg2, %int32_2613 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> %int1_2614 = torch.constant.int 1 - %2321 = torch.prim.ListConstruct %int4_2613, %int1_2614 : (!torch.int, !torch.int) -> !torch.list - %2322 = torch.aten.repeat %2316, %2321 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_2615 = torch.constant.int 32 - %2323 = torch.aten.mul.Scalar %2312, %int32_2615 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2616 = torch.constant.int 1 - %2324 = torch.aten.add.Tensor %2323, %2320, %int1_2616 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_2617 = torch.constant.int 2 - %2325 = torch.aten.mul.Scalar %2324, %int2_2617 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %2726 = torch.aten.unsqueeze %2725, %int1_2614 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_2615 = torch.constant.int 1 + %false_2616 = torch.constant.bool false + %2727 = torch.aten.gather %arg3, %int1_2615, %2726, %false_2616 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_2617 = torch.constant.int 4 %int1_2618 = torch.constant.int 1 - %2326 = torch.aten.add.Tensor %2325, %2322, %int1_2618 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_2619 = torch.constant.int 32 - %2327 = torch.aten.mul.Scalar %2326, %int32_2619 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2620 = torch.constant.int 1 - %2328 = torch.aten.add.Tensor %2327, %2314, %int1_2620 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %2329 = torch.prim.ListConstruct %2328 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_2621 = torch.constant.bool false - %2330 = torch.aten.index_put %2309, %2329, %2261, %false_2621 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2330, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_2622 = torch.constant.int 32 - %int2_2623 = torch.constant.int 2 - %int32_2624 = torch.constant.int 32 - %int8_2625 = torch.constant.int 8 - %int128_2626 = torch.constant.int 128 - %2331 = torch.prim.ListConstruct %437, %int32_2622, %int2_2623, %int32_2624, %int8_2625, %int128_2626 : (!torch.int, !torch.int, 
!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2332 = torch.aten.view %2330, %2331 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2332, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2627 = torch.constant.int 2097152 - %2333 = torch.prim.ListConstruct %437, %int2097152_2627 : (!torch.int, !torch.int) -> !torch.list - %2334 = torch.aten.view %2332, %2333 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2334, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_2628 = torch.constant.int 4 - %2335 = torch.prim.ListConstruct %int4_2628, %358 : (!torch.int, !torch.int) -> !torch.list + %int1_2619 = torch.constant.int 1 + %2728 = torch.prim.ListConstruct %int4_2617, %int1_2618, %int1_2619 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2729 = torch.aten.view %2727, %2728 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_2620 = torch.constant.int 32 + %2730 = torch.aten.remainder.Scalar %arg2, %int32_2620 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_2621 = torch.constant.int 4 + %int1_2622 = torch.constant.int 1 + %int1_2623 = torch.constant.int 1 + %2731 = torch.prim.ListConstruct %int4_2621, %int1_2622, %int1_2623 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2732 = torch.aten.view %2730, %2731 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_2624 = torch.constant.int 8 + %none_2625 = torch.constant.none + %none_2626 = torch.constant.none + %cpu_2627 = torch.constant.device "cpu" + %false_2628 = torch.constant.bool false + %2733 = torch.aten.arange %int8_2624, %none_2625, %none_2626, %cpu_2627, %false_2628 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> %int1_2629 = torch.constant.int 1 - %2336 = torch.prim.ListConstruct %358, %int1_2629 : (!torch.int, !torch.int) -> !torch.list - %int4_2630 = torch.constant.int 4 - %int0_2631 = torch.constant.int 0 - %cpu_2632 = torch.constant.device "cpu" - %false_2633 = torch.constant.bool false - %2337 = torch.aten.empty_strided %2335, %2336, %int4_2630, %int0_2631, %cpu_2632, %false_2633 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2337, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int9_2634 = torch.constant.int 9 - %2338 = torch.aten.fill.Scalar %2337, %int9_2634 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2338, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_2635 = torch.constant.int 32 - %2339 = torch.aten.mul.Scalar %arg3, %int32_2635 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2339, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_2636 = torch.constant.int 1 - %2340 = torch.aten.add.Tensor %2339, %2338, %int1_2636 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2340, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_2637 = torch.constant.int 4 - %2341 = torch.aten.mul.int %int4_2637, %358 : !torch.int, !torch.int -> !torch.int - %2342 = torch.prim.ListConstruct %2341 
: (!torch.int) -> !torch.list - %2343 = torch.aten.view %2340, %2342 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2343, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_2638 = torch.constant.int 32 - %int2_2639 = torch.constant.int 2 - %int32_2640 = torch.constant.int 32 - %int8_2641 = torch.constant.int 8 - %int128_2642 = torch.constant.int 128 - %2344 = torch.prim.ListConstruct %437, %int32_2638, %int2_2639, %int32_2640, %int8_2641, %int128_2642 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2345 = torch.aten.view %2334, %2344 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2345, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2643 = torch.constant.int 32 - %2346 = torch.aten.mul.int %437, %int32_2643 : !torch.int, !torch.int -> !torch.int - %int2_2644 = torch.constant.int 2 + %int1_2630 = torch.constant.int 1 + %int8_2631 = torch.constant.int 8 + %2734 = torch.prim.ListConstruct %int1_2629, %int1_2630, %int8_2631 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2735 = torch.aten.view %2733, %2734 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_2632 = torch.constant.none + %2736 = torch.aten.clone %146, %none_2632 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2737 = torch.aten.detach %2736 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2738 = torch.aten.detach %2737 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2739 = torch.aten.detach %2738 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_2633 = torch.constant.int 1 + %int1_2634 = torch.constant.int 1 + %int1_2635 = torch.constant.int 1 + %2740 = torch.prim.ListConstruct %int1_2633, %int1_2634, %int1_2635 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2741 = torch.aten.view %2739, %2740 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_2636 = torch.constant.int 32 + %2742 = torch.aten.mul.Scalar %2729, %int32_2636 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int10 = torch.constant.int 10 + %int1_2637 = torch.constant.int 1 + %2743 = torch.aten.add.Scalar %2742, %int10, %int1_2637 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_2638 = torch.constant.int 2 + %2744 = torch.aten.mul.Scalar %2743, %int2_2638 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_2639 = torch.constant.int 1 + %2745 = torch.aten.add.Tensor %2744, %2741, %int1_2639 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_2640 = torch.constant.int 8 + %2746 = torch.aten.mul.Scalar %2745, %int8_2640 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_2641 = torch.constant.int 1 + %2747 = torch.aten.add.Tensor %2746, %2735, %int1_2641 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_2642 = torch.constant.int 32 + %2748 = torch.aten.mul.Scalar %2747, %int32_2642 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_2643 = torch.constant.int 1 + %2749 = torch.aten.add.Tensor %2748, %2732, %int1_2643 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + 
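// The scalar arithmetic above appears to flatten (page, layer, K/V partition,
// kv head, position-within-block) into linear row indices for the paged cache
// viewed as [num_pages * 16384, 128], i.e.
//   (((page_id * 32 + layer) * 2 + partition) * 8 + head) * 32 + pos % 32,
// with the page id gathered from %arg3, layer = 10 (%int10) here, and the
// partition selector taken from %146. The torch.aten.index_put below scatters
// this step's rotated K rows (%2750) into those cache rows.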
%int5_2644 = torch.constant.int 5 + %2750 = torch.prims.convert_element_type %2724, %int5_2644 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> %int32_2645 = torch.constant.int 32 - %int8_2646 = torch.constant.int 8 - %int128_2647 = torch.constant.int 128 - %2347 = torch.prim.ListConstruct %2346, %int2_2644, %int32_2645, %int8_2646, %int128_2647 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2348 = torch.aten.view %2345, %2347 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %2348, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_2648 = torch.constant.int 0 - %2349 = torch.aten.index_select %2348, %int0_2648, %2343 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %2349, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_2649 = torch.constant.int 4 - %int2_2650 = torch.constant.int 2 - %int32_2651 = torch.constant.int 32 - %int8_2652 = torch.constant.int 8 - %int128_2653 = torch.constant.int 128 - %2350 = torch.prim.ListConstruct %int4_2649, %358, %int2_2650, %int32_2651, %int8_2652, %int128_2653 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2351 = torch.aten.view %2349, %2350 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2351, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_2654 = torch.constant.int 0 - %int0_2655 = torch.constant.int 0 - %int9223372036854775807_2656 = torch.constant.int 9223372036854775807 - %int1_2657 = torch.constant.int 1 - %2352 = torch.aten.slice.Tensor %2351, %int0_2654, %int0_2655, %int9223372036854775807_2656, %int1_2657 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2352, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_2658 = torch.constant.int 1 - %int0_2659 = torch.constant.int 0 - %int9223372036854775807_2660 = torch.constant.int 9223372036854775807 - %int1_2661 = torch.constant.int 1 - %2353 = torch.aten.slice.Tensor %2352, %int1_2658, %int0_2659, %int9223372036854775807_2660, %int1_2661 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2353, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_2662 = torch.constant.int 2 - %int0_2663 = torch.constant.int 0 - %2354 = torch.aten.select.int %2353, %int2_2662, %int0_2663 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2354, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_2664 = torch.constant.int 32 - %2355 = torch.aten.mul.int %358, %int32_2664 : !torch.int, !torch.int -> !torch.int - %int2_2665 = torch.constant.int 2 - %int0_2666 = torch.constant.int 0 + %int2_2646 = torch.constant.int 2 + %int8_2647 = torch.constant.int 8 + %int32_2648 = torch.constant.int 32 + %int128_2649 = torch.constant.int 128 + %2751 = torch.prim.ListConstruct %456, %int32_2645, %int2_2646, %int8_2647, 
%int32_2648, %int128_2649 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2752 = torch.aten.view %2572, %2751 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2752, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_2650 = torch.constant.int 128 + %2753 = torch.prim.ListConstruct %596, %int128_2650 : (!torch.int, !torch.int) -> !torch.list + %2754 = torch.aten.view %2752, %2753 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2754, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %2755 = torch.prim.ListConstruct %2749 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_2651 = torch.constant.bool false + %2756 = torch.aten.index_put %2754, %2755, %2750, %false_2651 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2756, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_2652 = torch.constant.int 32 + %int2_2653 = torch.constant.int 2 + %int8_2654 = torch.constant.int 8 + %int32_2655 = torch.constant.int 32 + %int128_2656 = torch.constant.int 128 + %2757 = torch.prim.ListConstruct %456, %int32_2652, %int2_2653, %int8_2654, %int32_2655, %int128_2656 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2758 = torch.aten.view %2756, %2757 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2758, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_2657 = torch.constant.int 2097152 + %2759 = torch.prim.ListConstruct %456, %int2097152_2657 : (!torch.int, !torch.int) -> !torch.list + %2760 = torch.aten.view %2758, %2759 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2760, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_2658 = torch.constant.int 32 + %int2_2659 = torch.constant.int 2 + %int8_2660 = torch.constant.int 8 + %int32_2661 = torch.constant.int 32 + %int128_2662 = torch.constant.int 128 + %2761 = torch.prim.ListConstruct %456, %int32_2658, %int2_2659, %int8_2660, %int32_2661, %int128_2662 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2762 = torch.aten.view %2760, %2761 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2762, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_2663 = torch.constant.int 128 + %2763 = torch.prim.ListConstruct %596, %int128_2663 : (!torch.int, !torch.int) -> !torch.list + %2764 = torch.aten.view %2762, %2763 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2764, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_2664 = torch.constant.none + %2765 = torch.aten.clone %147, %none_2664 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2766 = torch.aten.detach %2765 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2767 = torch.aten.detach %2766 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2768 = torch.aten.detach 
%2767 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_2665 = torch.constant.int 1 + %int1_2666 = torch.constant.int 1 %int1_2667 = torch.constant.int 1 - %2356 = torch.aten.slice.Tensor %2354, %int2_2665, %int0_2666, %2355, %int1_2667 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2356, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_2668 = torch.constant.int 0 - %2357 = torch.aten.clone %2356, %int0_2668 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2357, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_2669 = torch.constant.int 1 - %2358 = torch.aten.size.int %2353, %int1_2669 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_2670 = torch.constant.int 32 - %2359 = torch.aten.mul.int %2358, %int32_2670 : !torch.int, !torch.int -> !torch.int - %int4_2671 = torch.constant.int 4 - %int8_2672 = torch.constant.int 8 - %int128_2673 = torch.constant.int 128 - %2360 = torch.prim.ListConstruct %int4_2671, %2359, %int8_2672, %int128_2673 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2361 = torch.aten._unsafe_view %2357, %2360 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2361, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_2674 = torch.constant.int 0 - %int0_2675 = torch.constant.int 0 - %int9223372036854775807_2676 = torch.constant.int 9223372036854775807 - %int1_2677 = torch.constant.int 1 - %2362 = torch.aten.slice.Tensor %2361, %int0_2674, %int0_2675, %int9223372036854775807_2676, %int1_2677 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2362, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_2678 = torch.constant.int 0 - %int0_2679 = torch.constant.int 0 - %int9223372036854775807_2680 = torch.constant.int 9223372036854775807 - %int1_2681 = torch.constant.int 1 - %2363 = torch.aten.slice.Tensor %2351, %int0_2678, %int0_2679, %int9223372036854775807_2680, %int1_2681 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2363, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_2682 = torch.constant.int 1 - %int0_2683 = torch.constant.int 0 - %int9223372036854775807_2684 = torch.constant.int 9223372036854775807 - %int1_2685 = torch.constant.int 1 - %2364 = torch.aten.slice.Tensor %2363, %int1_2682, %int0_2683, %int9223372036854775807_2684, %int1_2685 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2364, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_2686 = torch.constant.int 2 - %int1_2687 = torch.constant.int 1 - %2365 = torch.aten.select.int %2364, %int2_2686, %int1_2687 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2365, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_2688 = 
torch.constant.int 2 - %int0_2689 = torch.constant.int 0 - %int1_2690 = torch.constant.int 1 - %2366 = torch.aten.slice.Tensor %2365, %int2_2688, %int0_2689, %2355, %int1_2690 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2366, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_2691 = torch.constant.int 0 - %2367 = torch.aten.clone %2366, %int0_2691 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2367, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_2692 = torch.constant.int 1 - %2368 = torch.aten.size.int %2364, %int1_2692 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_2693 = torch.constant.int 32 - %2369 = torch.aten.mul.int %2368, %int32_2693 : !torch.int, !torch.int -> !torch.int - %int4_2694 = torch.constant.int 4 - %int8_2695 = torch.constant.int 8 - %int128_2696 = torch.constant.int 128 - %2370 = torch.prim.ListConstruct %int4_2694, %2369, %int8_2695, %int128_2696 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2371 = torch.aten._unsafe_view %2367, %2370 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2371, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_2697 = torch.constant.int 0 - %int0_2698 = torch.constant.int 0 - %int9223372036854775807_2699 = torch.constant.int 9223372036854775807 - %int1_2700 = torch.constant.int 1 - %2372 = torch.aten.slice.Tensor %2371, %int0_2697, %int0_2698, %int9223372036854775807_2699, %int1_2700 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2372, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_2701 = torch.constant.int -2 - %2373 = torch.aten.unsqueeze %2362, %int-2_2701 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2373, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_2702 = torch.constant.int 1 - %2374 = torch.aten.size.int %2361, %int1_2702 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_2703 = torch.constant.int 4 - %int8_2704 = torch.constant.int 8 - %int4_2705 = torch.constant.int 4 - %int128_2706 = torch.constant.int 128 - %2375 = torch.prim.ListConstruct %int4_2703, %2374, %int8_2704, %int4_2705, %int128_2706 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2707 = torch.constant.bool false - %2376 = torch.aten.expand %2373, %2375, %false_2707 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2376, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2708 = torch.constant.int 0 - %2377 = torch.aten.clone %2376, %int0_2708 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2377, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2709 = torch.constant.int 4 - %int32_2710 = torch.constant.int 32 - %int128_2711 = torch.constant.int 128 - %2378 = torch.prim.ListConstruct %int4_2709, %2374, 
%int32_2710, %int128_2711 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2379 = torch.aten._unsafe_view %2377, %2378 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2379, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_2712 = torch.constant.int -2 - %2380 = torch.aten.unsqueeze %2372, %int-2_2712 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2380, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_2713 = torch.constant.int 1 - %2381 = torch.aten.size.int %2371, %int1_2713 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_2714 = torch.constant.int 4 - %int8_2715 = torch.constant.int 8 + %2769 = torch.prim.ListConstruct %int1_2665, %int1_2666, %int1_2667 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2770 = torch.aten.view %2768, %2769 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_2668 = torch.constant.int 32 + %2771 = torch.aten.mul.Scalar %2729, %int32_2668 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int10_2669 = torch.constant.int 10 + %int1_2670 = torch.constant.int 1 + %2772 = torch.aten.add.Scalar %2771, %int10_2669, %int1_2670 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_2671 = torch.constant.int 2 + %2773 = torch.aten.mul.Scalar %2772, %int2_2671 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_2672 = torch.constant.int 1 + %2774 = torch.aten.add.Tensor %2773, %2770, %int1_2672 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_2673 = torch.constant.int 8 + %2775 = torch.aten.mul.Scalar %2774, %int8_2673 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_2674 = torch.constant.int 1 + %2776 = torch.aten.add.Tensor %2775, %2735, %int1_2674 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_2675 = torch.constant.int 32 + %2777 = torch.aten.mul.Scalar %2776, %int32_2675 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_2676 = torch.constant.int 1 + %2778 = torch.aten.add.Tensor %2777, %2732, %int1_2676 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_2677 = torch.constant.int 5 + %2779 = torch.prims.convert_element_type %2704, %int5_2677 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %2780 = torch.prim.ListConstruct %2778 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_2678 = torch.constant.bool false + %2781 = torch.aten.index_put %2764, %2780, %2779, %false_2678 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2781, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_2679 = torch.constant.int 32 + %int2_2680 = torch.constant.int 2 + %int8_2681 = torch.constant.int 8 + %int32_2682 = torch.constant.int 32 + %int128_2683 = torch.constant.int 128 + %2782 = torch.prim.ListConstruct %456, %int32_2679, %int2_2680, %int8_2681, %int32_2682, %int128_2683 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + 
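// Above: the same flattening with the other partition selector (%147,
// presumably the V half of each page) built the row indices %2778, and the
// second index_put (%2781) scattered this step's rotated V rows into the
// cache. Below: the cache is viewed back out to [?,2097152] and the
// paged_attention_kv_cache_gather util.calls re-read the K and V pages for
// attention.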
%2783 = torch.aten.view %2781, %2782 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2783, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_2684 = torch.constant.int 2097152 + %2784 = torch.prim.ListConstruct %456, %int2097152_2684 : (!torch.int, !torch.int) -> !torch.list + %2785 = torch.aten.view %2783, %2784 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2785, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_2685 = torch.constant.none + %2786 = torch.aten.clone %148, %none_2685 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2787 = torch.aten.detach %2786 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2788 = torch.aten.detach %2787 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2789 = torch.aten.detach %2788 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_2686 = torch.constant.none + %2790 = torch.aten.clone %149, %none_2686 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2791 = torch.aten.detach %2790 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2792 = torch.aten.detach %2791 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2793 = torch.aten.detach %2792 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_2687 = torch.constant.none + %2794 = torch.aten.clone %150, %none_2687 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2795 = torch.aten.detach %2794 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2796 = torch.aten.detach %2795 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2797 = torch.aten.detach %2796 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_2688 = torch.constant.int 32 + %int2_2689 = torch.constant.int 2 + %int8_2690 = torch.constant.int 8 + %int32_2691 = torch.constant.int 32 + %int128_2692 = torch.constant.int 128 + %2798 = torch.prim.ListConstruct %456, %int32_2688, %int2_2689, %int8_2690, %int32_2691, %int128_2692 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2799 = torch.aten.view %2785, %2798 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2799, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %2800 = torch_c.to_builtin_tensor %2799 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %2801 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_2693 = tensor.cast %2801 : tensor<4x?xi64> to tensor + %2802 = torch_c.to_builtin_tensor %2789 : !torch.vtensor<[],si64> -> tensor + %2803 = torch_c.to_builtin_tensor %2793 : !torch.vtensor<[],si64> -> tensor + %2804 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%2800, %cast_2693, %2802, %2803) : (tensor, tensor, tensor, tensor) -> tensor + %cast_2694 = tensor.cast %2804 : tensor to tensor<4x?x8x32x128xf16> + %2805 = torch_c.from_builtin_tensor %cast_2694 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %2805, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %2806 = torch_c.to_builtin_tensor %2799 
: !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %2807 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_2695 = tensor.cast %2807 : tensor<4x?xi64> to tensor + %2808 = torch_c.to_builtin_tensor %2789 : !torch.vtensor<[],si64> -> tensor + %2809 = torch_c.to_builtin_tensor %2797 : !torch.vtensor<[],si64> -> tensor + %2810 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%2806, %cast_2695, %2808, %2809) : (tensor, tensor, tensor, tensor) -> tensor + %cast_2696 = tensor.cast %2810 : tensor to tensor<4x?x8x32x128xf16> + %2811 = torch_c.from_builtin_tensor %cast_2696 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %2811, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_2697 = torch.constant.int 2 + %int3_2698 = torch.constant.int 3 + %2812 = torch.aten.transpose.int %2805, %int2_2697, %int3_2698 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2812, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_2699 = torch.constant.int 0 + %2813 = torch.aten.clone %2812, %int0_2699 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2813, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_2700 = torch.constant.int 4 + %int8_2701 = torch.constant.int 8 + %int128_2702 = torch.constant.int 128 + %2814 = torch.prim.ListConstruct %int4_2700, %457, %int8_2701, %int128_2702 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2815 = torch.aten._unsafe_view %2813, %2814 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2815, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_2703 = torch.constant.int 2 + %int3_2704 = torch.constant.int 3 + %2816 = torch.aten.transpose.int %2811, %int2_2703, %int3_2704 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2816, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_2705 = torch.constant.int 0 + %2817 = torch.aten.clone %2816, %int0_2705 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %2817, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_2706 = torch.constant.int 4 + %int8_2707 = torch.constant.int 8 + %int128_2708 = torch.constant.int 128 + %2818 = torch.prim.ListConstruct %int4_2706, %457, %int8_2707, %int128_2708 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2819 = torch.aten._unsafe_view %2817, %2818 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %2819, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_2709 = torch.constant.int -2 + %2820 = torch.aten.unsqueeze %2815, %int-2_2709 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2820, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : 
!torch.vtensor<[4,?,8,1,128],f16> + %int4_2710 = torch.constant.int 4 + %int8_2711 = torch.constant.int 8 + %int4_2712 = torch.constant.int 4 + %int128_2713 = torch.constant.int 128 + %2821 = torch.prim.ListConstruct %int4_2710, %457, %int8_2711, %int4_2712, %int128_2713 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2714 = torch.constant.bool false + %2822 = torch.aten.expand %2820, %2821, %false_2714 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2822, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2715 = torch.constant.int 0 + %2823 = torch.aten.clone %2822, %int0_2715 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2823, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_2716 = torch.constant.int 4 - %int128_2717 = torch.constant.int 128 - %2382 = torch.prim.ListConstruct %int4_2714, %2381, %int8_2715, %int4_2716, %int128_2717 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2718 = torch.constant.bool false - %2383 = torch.aten.expand %2380, %2382, %false_2718 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2383, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2719 = torch.constant.int 0 - %2384 = torch.aten.clone %2383, %int0_2719 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2384, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int32_2717 = torch.constant.int 32 + %int128_2718 = torch.constant.int 128 + %2824 = torch.prim.ListConstruct %int4_2716, %457, %int32_2717, %int128_2718 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2825 = torch.aten._unsafe_view %2823, %2824 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2825, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_2719 = torch.constant.int -2 + %2826 = torch.aten.unsqueeze %2819, %int-2_2719 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %2826, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> %int4_2720 = torch.constant.int 4 - %int32_2721 = torch.constant.int 32 - %int128_2722 = torch.constant.int 128 - %2385 = torch.prim.ListConstruct %int4_2720, %2381, %int32_2721, %int128_2722 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2386 = torch.aten._unsafe_view %2384, %2385 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2386, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_2723 = torch.constant.int 1 - %int2_2724 = torch.constant.int 2 - %2387 = torch.aten.transpose.int %2267, %int1_2723, %int2_2724 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_2725 = torch.constant.int 1 - %int2_2726 = torch.constant.int 2 - %2388 = torch.aten.transpose.int %2379, %int1_2725, %int2_2726 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> 
!torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2388, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_2727 = torch.constant.int 1 - %int2_2728 = torch.constant.int 2 - %2389 = torch.aten.transpose.int %2386, %int1_2727, %int2_2728 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2389, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_2729 = torch.constant.float 0.000000e+00 - %false_2730 = torch.constant.bool false - %none_2731 = torch.constant.none - %2390:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2387, %2388, %2389, %float0.000000e00_2729, %false_2730, %368, %none_2731) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_2732 = torch.constant.int 1 - %int2_2733 = torch.constant.int 2 - %2391 = torch.aten.transpose.int %2390#0, %int1_2732, %int2_2733 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_2734 = torch.constant.int 4 - %int1_2735 = torch.constant.int 1 - %int4096_2736 = torch.constant.int 4096 - %2392 = torch.prim.ListConstruct %int4_2734, %int1_2735, %int4096_2736 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2393 = torch.aten.view %2391, %2392 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_2737 = torch.constant.int -2 - %int-1_2738 = torch.constant.int -1 - %2394 = torch.aten.transpose.int %106, %int-2_2737, %int-1_2738 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2739 = torch.constant.int 4 - %int4096_2740 = torch.constant.int 4096 - %2395 = torch.prim.ListConstruct %int4_2739, %int4096_2740 : (!torch.int, !torch.int) -> !torch.list - %2396 = torch.aten.view %2393, %2395 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2397 = torch.aten.mm %2396, %2394 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_2741 = torch.constant.int 4 - %int1_2742 = torch.constant.int 1 - %int4096_2743 = torch.constant.int 4096 - %2398 = torch.prim.ListConstruct %int4_2741, %int1_2742, %int4096_2743 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2399 = torch.aten.view %2397, %2398 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_2744 = torch.constant.int 1 - %2400 = torch.aten.add.Tensor %2227, %2399, %int1_2744 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_2745 = torch.constant.int 6 - %2401 = torch.prims.convert_element_type %2400, %int6_2745 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_2746 = torch.constant.int 2 - %2402 = torch.aten.pow.Tensor_Scalar %2401, %int2_2746 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_2747 = torch.constant.int -1 - %2403 = torch.prim.ListConstruct %int-1_2747 : (!torch.int) -> !torch.list - %true_2748 = torch.constant.bool true - %none_2749 = torch.constant.none - %2404 = torch.aten.mean.dim %2402, %2403, %true_2748, %none_2749 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> 
!torch.vtensor<[4,1,1],f32> - %float9.999990e-06_2750 = torch.constant.float 9.9999997473787516E-6 + %int8_2721 = torch.constant.int 8 + %int4_2722 = torch.constant.int 4 + %int128_2723 = torch.constant.int 128 + %2827 = torch.prim.ListConstruct %int4_2720, %457, %int8_2721, %int4_2722, %int128_2723 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2724 = torch.constant.bool false + %2828 = torch.aten.expand %2826, %2827, %false_2724 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2828, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2725 = torch.constant.int 0 + %2829 = torch.aten.clone %2828, %int0_2725 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %2829, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_2726 = torch.constant.int 4 + %int32_2727 = torch.constant.int 32 + %int128_2728 = torch.constant.int 128 + %2830 = torch.prim.ListConstruct %int4_2726, %457, %int32_2727, %int128_2728 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2831 = torch.aten._unsafe_view %2829, %2830 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %2831, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_2729 = torch.constant.int 1 + %int2_2730 = torch.constant.int 2 + %2832 = torch.aten.transpose.int %2714, %int1_2729, %int2_2730 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_2731 = torch.constant.int 1 + %int2_2732 = torch.constant.int 2 + %2833 = torch.aten.transpose.int %2825, %int1_2731, %int2_2732 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2833, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2733 = torch.constant.int 1 + %int2_2734 = torch.constant.int 2 + %2834 = torch.aten.transpose.int %2831, %int1_2733, %int2_2734 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %2834, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_2735 = torch.constant.float 0.000000e+00 + %false_2736 = torch.constant.bool false + %none_2737 = torch.constant.none + %2835:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2832, %2833, %2834, %float0.000000e00_2735, %false_2736, %470, %none_2737) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_2738 = torch.constant.int 1 + %int2_2739 = torch.constant.int 2 + %2836 = torch.aten.transpose.int %2835#0, %int1_2738, %int2_2739 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_2740 = torch.constant.int 4 + %int1_2741 = torch.constant.int 1 + %int4096_2742 = torch.constant.int 4096 + %2837 = torch.prim.ListConstruct %int4_2740, %int1_2741, %int4096_2742 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2838 = torch.aten.view %2836, %2837 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> 
!torch.vtensor<[4,1,4096],f16> + %int-2_2743 = torch.constant.int -2 + %int-1_2744 = torch.constant.int -1 + %2839 = torch.aten.transpose.int %151, %int-2_2743, %int-1_2744 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2745 = torch.constant.int 5 + %2840 = torch.prims.convert_element_type %2839, %int5_2745 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_2746 = torch.constant.int 4 + %int4096_2747 = torch.constant.int 4096 + %2841 = torch.prim.ListConstruct %int4_2746, %int4096_2747 : (!torch.int, !torch.int) -> !torch.list + %2842 = torch.aten.view %2838, %2841 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2843 = torch.aten.mm %2842, %2840 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_2748 = torch.constant.int 4 + %int1_2749 = torch.constant.int 1 + %int4096_2750 = torch.constant.int 4096 + %2844 = torch.prim.ListConstruct %int4_2748, %int1_2749, %int4096_2750 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2845 = torch.aten.view %2843, %2844 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_2751 = torch.constant.int 1 - %2405 = torch.aten.add.Scalar %2404, %float9.999990e-06_2750, %int1_2751 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %2406 = torch.aten.rsqrt %2405 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %2407 = torch.aten.mul.Tensor %2401, %2406 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_2752 = torch.constant.int 5 - %2408 = torch.prims.convert_element_type %2407, %int5_2752 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %2409 = torch.aten.mul.Tensor %107, %2408 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_2753 = torch.constant.int 5 - %2410 = torch.prims.convert_element_type %2409, %int5_2753 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_2754 = torch.constant.int -2 - %int-1_2755 = torch.constant.int -1 - %2411 = torch.aten.transpose.int %108, %int-2_2754, %int-1_2755 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_2756 = torch.constant.int 4 - %int4096_2757 = torch.constant.int 4096 - %2412 = torch.prim.ListConstruct %int4_2756, %int4096_2757 : (!torch.int, !torch.int) -> !torch.list - %2413 = torch.aten.view %2410, %2412 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2414 = torch.aten.mm %2413, %2411 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_2758 = torch.constant.int 4 - %int1_2759 = torch.constant.int 1 - %int14336_2760 = torch.constant.int 14336 - %2415 = torch.prim.ListConstruct %int4_2758, %int1_2759, %int14336_2760 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2416 = torch.aten.view %2414, %2415 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %2417 = torch.aten.silu %2416 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %2846 = torch.aten.add.Tensor %2667, %2845, %int1_2751 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_2752 = torch.constant.int 6 + %2847 = torch.prims.convert_element_type %2846, %int6_2752 
: !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_2753 = torch.constant.int 2 + %2848 = torch.aten.pow.Tensor_Scalar %2847, %int2_2753 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_2754 = torch.constant.int -1 + %2849 = torch.prim.ListConstruct %int-1_2754 : (!torch.int) -> !torch.list + %true_2755 = torch.constant.bool true + %none_2756 = torch.constant.none + %2850 = torch.aten.mean.dim %2848, %2849, %true_2755, %none_2756 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_2757 = torch.constant.float 9.9999997473787516E-6 + %int1_2758 = torch.constant.int 1 + %2851 = torch.aten.add.Scalar %2850, %float9.999990e-06_2757, %int1_2758 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %2852 = torch.aten.rsqrt %2851 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %2853 = torch.aten.mul.Tensor %2847, %2852 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_2759 = torch.constant.int 5 + %2854 = torch.prims.convert_element_type %2853, %int5_2759 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %2855 = torch.aten.mul.Tensor %152, %2854 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_2760 = torch.constant.int 5 + %2856 = torch.prims.convert_element_type %2855, %int5_2760 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> %int-2_2761 = torch.constant.int -2 %int-1_2762 = torch.constant.int -1 - %2418 = torch.aten.transpose.int %109, %int-2_2761, %int-1_2762 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_2763 = torch.constant.int 4 - %int4096_2764 = torch.constant.int 4096 - %2419 = torch.prim.ListConstruct %int4_2763, %int4096_2764 : (!torch.int, !torch.int) -> !torch.list - %2420 = torch.aten.view %2410, %2419 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2421 = torch.aten.mm %2420, %2418 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_2765 = torch.constant.int 4 - %int1_2766 = torch.constant.int 1 - %int14336_2767 = torch.constant.int 14336 - %2422 = torch.prim.ListConstruct %int4_2765, %int1_2766, %int14336_2767 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2423 = torch.aten.view %2421, %2422 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %2424 = torch.aten.mul.Tensor %2417, %2423 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_2768 = torch.constant.int -2 - %int-1_2769 = torch.constant.int -1 - %2425 = torch.aten.transpose.int %110, %int-2_2768, %int-1_2769 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_2770 = torch.constant.int 4 - %int14336_2771 = torch.constant.int 14336 - %2426 = torch.prim.ListConstruct %int4_2770, %int14336_2771 : (!torch.int, !torch.int) -> !torch.list - %2427 = torch.aten.view %2424, %2426 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %2428 = torch.aten.mm %2427, %2425 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %2857 = torch.aten.transpose.int %153, %int-2_2761, %int-1_2762 : 
!torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2763 = torch.constant.int 5 + %2858 = torch.prims.convert_element_type %2857, %int5_2763 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_2764 = torch.constant.int 4 + %int4096_2765 = torch.constant.int 4096 + %2859 = torch.prim.ListConstruct %int4_2764, %int4096_2765 : (!torch.int, !torch.int) -> !torch.list + %2860 = torch.aten.view %2856, %2859 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2861 = torch.aten.mm %2860, %2858 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_2766 = torch.constant.int 4 + %int1_2767 = torch.constant.int 1 + %int14336_2768 = torch.constant.int 14336 + %2862 = torch.prim.ListConstruct %int4_2766, %int1_2767, %int14336_2768 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2863 = torch.aten.view %2861, %2862 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %2864 = torch.aten.silu %2863 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_2769 = torch.constant.int -2 + %int-1_2770 = torch.constant.int -1 + %2865 = torch.aten.transpose.int %154, %int-2_2769, %int-1_2770 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_2771 = torch.constant.int 5 + %2866 = torch.prims.convert_element_type %2865, %int5_2771 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> %int4_2772 = torch.constant.int 4 - %int1_2773 = torch.constant.int 1 - %int4096_2774 = torch.constant.int 4096 - %2429 = torch.prim.ListConstruct %int4_2772, %int1_2773, %int4096_2774 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2430 = torch.aten.view %2428, %2429 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int4096_2773 = torch.constant.int 4096 + %2867 = torch.prim.ListConstruct %int4_2772, %int4096_2773 : (!torch.int, !torch.int) -> !torch.list + %2868 = torch.aten.view %2856, %2867 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2869 = torch.aten.mm %2868, %2866 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_2774 = torch.constant.int 4 %int1_2775 = torch.constant.int 1 - %2431 = torch.aten.add.Tensor %2400, %2430, %int1_2775 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_2776 = torch.constant.int 6 - %2432 = torch.prims.convert_element_type %2431, %int6_2776 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_2777 = torch.constant.int 2 - %2433 = torch.aten.pow.Tensor_Scalar %2432, %int2_2777 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int14336_2776 = torch.constant.int 14336 + %2870 = torch.prim.ListConstruct %int4_2774, %int1_2775, %int14336_2776 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2871 = torch.aten.view %2869, %2870 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %2872 = torch.aten.mul.Tensor %2864, %2871 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_2777 = torch.constant.int -2 %int-1_2778 = torch.constant.int -1 - %2434 = torch.prim.ListConstruct %int-1_2778 : (!torch.int) -> !torch.list - %true_2779 = torch.constant.bool true 
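// Reader annotation (not generated IR; a hedged reading of the hunk above): both sides compute
// the decoder layer's RMSNorm + SwiGLU feed-forward:
//   h = x + attn_out;  y = h * rsqrt(mean(h^2, dim=-1) + eps) * ffn_norm_weight, eps = f32(1e-5)
//   out = W_down(silu(W_gate(y)) * W_up(y)),  4096 -> 14336 -> 4096
// The visible change on the `+` side is a torch.prims.convert_element_type (dtype 5 = f16)
// inserted after every weight transpose -- an f16 -> f16 cast, presumably to keep the
// transposed weight materialized ahead of torch.aten.mm. The rest of the hunk is SSA
// renumbering (%24xx -> %28xx, and the weight globals %107.. shifted to %151..).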
- %none_2780 = torch.constant.none - %2435 = torch.aten.mean.dim %2433, %2434, %true_2779, %none_2780 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_2781 = torch.constant.float 9.9999997473787516E-6 - %int1_2782 = torch.constant.int 1 - %2436 = torch.aten.add.Scalar %2435, %float9.999990e-06_2781, %int1_2782 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %2437 = torch.aten.rsqrt %2436 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %2438 = torch.aten.mul.Tensor %2432, %2437 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_2783 = torch.constant.int 5 - %2439 = torch.prims.convert_element_type %2438, %int5_2783 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %2440 = torch.aten.mul.Tensor %111, %2439 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_2784 = torch.constant.int 5 - %2441 = torch.prims.convert_element_type %2440, %int5_2784 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_2785 = torch.constant.int -2 - %int-1_2786 = torch.constant.int -1 - %2442 = torch.aten.transpose.int %112, %int-2_2785, %int-1_2786 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_2787 = torch.constant.int 4 - %int4096_2788 = torch.constant.int 4096 - %2443 = torch.prim.ListConstruct %int4_2787, %int4096_2788 : (!torch.int, !torch.int) -> !torch.list - %2444 = torch.aten.view %2441, %2443 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2445 = torch.aten.mm %2444, %2442 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_2789 = torch.constant.int 4 - %int1_2790 = torch.constant.int 1 - %int4096_2791 = torch.constant.int 4096 - %2446 = torch.prim.ListConstruct %int4_2789, %int1_2790, %int4096_2791 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2447 = torch.aten.view %2445, %2446 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_2792 = torch.constant.int -2 - %int-1_2793 = torch.constant.int -1 - %2448 = torch.aten.transpose.int %113, %int-2_2792, %int-1_2793 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2794 = torch.constant.int 4 - %int4096_2795 = torch.constant.int 4096 - %2449 = torch.prim.ListConstruct %int4_2794, %int4096_2795 : (!torch.int, !torch.int) -> !torch.list - %2450 = torch.aten.view %2441, %2449 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2451 = torch.aten.mm %2450, %2448 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_2796 = torch.constant.int 4 - %int1_2797 = torch.constant.int 1 - %int1024_2798 = torch.constant.int 1024 - %2452 = torch.prim.ListConstruct %int4_2796, %int1_2797, %int1024_2798 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2453 = torch.aten.view %2451, %2452 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_2799 = torch.constant.int -2 - %int-1_2800 = torch.constant.int -1 - %2454 = torch.aten.transpose.int %114, %int-2_2799, %int-1_2800 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_2801 = torch.constant.int 4 + %2873 = torch.aten.transpose.int 
%155, %int-2_2777, %int-1_2778 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_2779 = torch.constant.int 5 + %2874 = torch.prims.convert_element_type %2873, %int5_2779 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_2780 = torch.constant.int 4 + %int14336_2781 = torch.constant.int 14336 + %2875 = torch.prim.ListConstruct %int4_2780, %int14336_2781 : (!torch.int, !torch.int) -> !torch.list + %2876 = torch.aten.view %2872, %2875 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %2877 = torch.aten.mm %2876, %2874 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_2782 = torch.constant.int 4 + %int1_2783 = torch.constant.int 1 + %int4096_2784 = torch.constant.int 4096 + %2878 = torch.prim.ListConstruct %int4_2782, %int1_2783, %int4096_2784 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2879 = torch.aten.view %2877, %2878 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_2785 = torch.constant.int 1 + %2880 = torch.aten.add.Tensor %2846, %2879, %int1_2785 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_2786 = torch.constant.int 6 + %2881 = torch.prims.convert_element_type %2880, %int6_2786 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_2787 = torch.constant.int 2 + %2882 = torch.aten.pow.Tensor_Scalar %2881, %int2_2787 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_2788 = torch.constant.int -1 + %2883 = torch.prim.ListConstruct %int-1_2788 : (!torch.int) -> !torch.list + %true_2789 = torch.constant.bool true + %none_2790 = torch.constant.none + %2884 = torch.aten.mean.dim %2882, %2883, %true_2789, %none_2790 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_2791 = torch.constant.float 9.9999997473787516E-6 + %int1_2792 = torch.constant.int 1 + %2885 = torch.aten.add.Scalar %2884, %float9.999990e-06_2791, %int1_2792 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %2886 = torch.aten.rsqrt %2885 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %2887 = torch.aten.mul.Tensor %2881, %2886 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_2793 = torch.constant.int 5 + %2888 = torch.prims.convert_element_type %2887, %int5_2793 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %2889 = torch.aten.mul.Tensor %156, %2888 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_2794 = torch.constant.int 5 + %2890 = torch.prims.convert_element_type %2889, %int5_2794 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_2795 = torch.constant.int -2 + %int-1_2796 = torch.constant.int -1 + %2891 = torch.aten.transpose.int %157, %int-2_2795, %int-1_2796 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2797 = torch.constant.int 5 + %2892 = torch.prims.convert_element_type %2891, %int5_2797 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_2798 = torch.constant.int 4 + %int4096_2799 = torch.constant.int 4096 + %2893 = torch.prim.ListConstruct %int4_2798, 
%int4096_2799 : (!torch.int, !torch.int) -> !torch.list + %2894 = torch.aten.view %2890, %2893 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2895 = torch.aten.mm %2894, %2892 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_2800 = torch.constant.int 4 + %int1_2801 = torch.constant.int 1 %int4096_2802 = torch.constant.int 4096 - %2455 = torch.prim.ListConstruct %int4_2801, %int4096_2802 : (!torch.int, !torch.int) -> !torch.list - %2456 = torch.aten.view %2441, %2455 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2457 = torch.aten.mm %2456, %2454 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_2803 = torch.constant.int 4 - %int1_2804 = torch.constant.int 1 - %int1024_2805 = torch.constant.int 1024 - %2458 = torch.prim.ListConstruct %int4_2803, %int1_2804, %int1024_2805 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2459 = torch.aten.view %2457, %2458 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %2896 = torch.prim.ListConstruct %int4_2800, %int1_2801, %int4096_2802 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2897 = torch.aten.view %2895, %2896 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_2803 = torch.constant.int -2 + %int-1_2804 = torch.constant.int -1 + %2898 = torch.aten.transpose.int %158, %int-2_2803, %int-1_2804 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2805 = torch.constant.int 5 + %2899 = torch.prims.convert_element_type %2898, %int5_2805 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> %int4_2806 = torch.constant.int 4 - %int1_2807 = torch.constant.int 1 - %int32_2808 = torch.constant.int 32 - %int128_2809 = torch.constant.int 128 - %2460 = torch.prim.ListConstruct %int4_2806, %int1_2807, %int32_2808, %int128_2809 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2461 = torch.aten.view %2447, %2460 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_2810 = torch.constant.int 4 - %int1_2811 = torch.constant.int 1 - %int8_2812 = torch.constant.int 8 - %int128_2813 = torch.constant.int 128 - %2462 = torch.prim.ListConstruct %int4_2810, %int1_2811, %int8_2812, %int128_2813 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2463 = torch.aten.view %2453, %2462 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4096_2807 = torch.constant.int 4096 + %2900 = torch.prim.ListConstruct %int4_2806, %int4096_2807 : (!torch.int, !torch.int) -> !torch.list + %2901 = torch.aten.view %2890, %2900 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2902 = torch.aten.mm %2901, %2899 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_2808 = torch.constant.int 4 + %int1_2809 = torch.constant.int 1 + %int1024_2810 = torch.constant.int 1024 + %2903 = torch.prim.ListConstruct %int4_2808, %int1_2809, %int1024_2810 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2904 = torch.aten.view %2902, %2903 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_2811 = torch.constant.int -2 + %int-1_2812 = torch.constant.int -1 + %2905 = torch.aten.transpose.int %159, %int-2_2811, %int-1_2812 : !torch.vtensor<[1024,4096],f16>, 
!torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_2813 = torch.constant.int 5 + %2906 = torch.prims.convert_element_type %2905, %int5_2813 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> %int4_2814 = torch.constant.int 4 - %int1_2815 = torch.constant.int 1 - %int8_2816 = torch.constant.int 8 - %int128_2817 = torch.constant.int 128 - %2464 = torch.prim.ListConstruct %int4_2814, %int1_2815, %int8_2816, %int128_2817 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2465 = torch.aten.view %2459, %2464 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_2818 = torch.constant.int 6 - %2466 = torch.prims.convert_element_type %2461, %int6_2818 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %2467 = torch_c.to_builtin_tensor %2466 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %2468 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %2469 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%2467, %2468) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %2470 = torch_c.from_builtin_tensor %2469 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_2819 = torch.constant.int 5 - %2471 = torch.prims.convert_element_type %2470, %int5_2819 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_2820 = torch.constant.int 6 - %2472 = torch.prims.convert_element_type %2463, %int6_2820 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %2473 = torch_c.to_builtin_tensor %2472 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %2474 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %2475 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%2473, %2474) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %2476 = torch_c.from_builtin_tensor %2475 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_2821 = torch.constant.int 5 - %2477 = torch.prims.convert_element_type %2476, %int5_2821 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_2822 = torch.constant.int 32 - %2478 = torch.aten.floor_divide.Scalar %arg2, %int32_2822 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2823 = torch.constant.int 1 - %2479 = torch.aten.unsqueeze %2478, %int1_2823 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int4096_2815 = torch.constant.int 4096 + %2907 = torch.prim.ListConstruct %int4_2814, %int4096_2815 : (!torch.int, !torch.int) -> !torch.list + %2908 = torch.aten.view %2890, %2907 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %2909 = torch.aten.mm %2908, %2906 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_2816 = torch.constant.int 4 + %int1_2817 = torch.constant.int 1 + %int1024_2818 = torch.constant.int 1024 + %2910 = torch.prim.ListConstruct %int4_2816, %int1_2817, %int1024_2818 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2911 = torch.aten.view %2909, %2910 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_2819 = torch.constant.int 4 + %int1_2820 = torch.constant.int 1 + %int32_2821 = torch.constant.int 32 + %int128_2822 = torch.constant.int 128 + %2912 = torch.prim.ListConstruct %int4_2819, %int1_2820, 
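// Reader annotation: the `+` statements above are the next layer's attention input norm and
// QKV projections: RMSNorm scaled by the attn_norm weight %156, then Q = x @ %157^T
// (4096 -> 4096) and K = x @ %158^T, V = x @ %159^T (4096 -> 1024). The reshapes that follow
// split heads as 4096 = 32 heads x 128 and 1024 = 8 KV heads x 128, i.e. grouped-query
// attention with a group size of 4.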
%int32_2821, %int128_2822 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2913 = torch.aten.view %2897, %2912 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_2823 = torch.constant.int 4 %int1_2824 = torch.constant.int 1 - %false_2825 = torch.constant.bool false - %2480 = torch.aten.gather %arg3, %int1_2824, %2479, %false_2825 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_2826 = torch.constant.int 32 - %2481 = torch.aten.remainder.Scalar %arg2, %int32_2826 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2827 = torch.constant.int 1 - %2482 = torch.aten.unsqueeze %2481, %int1_2827 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_2828 = torch.constant.none - %2483 = torch.aten.clone %115, %none_2828 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_2829 = torch.constant.int 0 - %2484 = torch.aten.unsqueeze %2483, %int0_2829 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_2830 = torch.constant.int 4 + %int8_2825 = torch.constant.int 8 + %int128_2826 = torch.constant.int 128 + %2914 = torch.prim.ListConstruct %int4_2823, %int1_2824, %int8_2825, %int128_2826 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2915 = torch.aten.view %2904, %2914 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_2827 = torch.constant.int 4 + %int1_2828 = torch.constant.int 1 + %int8_2829 = torch.constant.int 8 + %int128_2830 = torch.constant.int 128 + %2916 = torch.prim.ListConstruct %int4_2827, %int1_2828, %int8_2829, %int128_2830 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2917 = torch.aten.view %2911, %2916 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> %int1_2831 = torch.constant.int 1 - %2485 = torch.prim.ListConstruct %int4_2830, %int1_2831 : (!torch.int, !torch.int) -> !torch.list - %int1_2832 = torch.constant.int 1 - %int1_2833 = torch.constant.int 1 - %2486 = torch.prim.ListConstruct %int1_2832, %int1_2833 : (!torch.int, !torch.int) -> !torch.list - %int4_2834 = torch.constant.int 4 - %int0_2835 = torch.constant.int 0 - %cpu_2836 = torch.constant.device "cpu" - %false_2837 = torch.constant.bool false - %2487 = torch.aten.empty_strided %2485, %2486, %int4_2834, %int0_2835, %cpu_2836, %false_2837 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int10 = torch.constant.int 10 - %2488 = torch.aten.fill.Scalar %2487, %int10 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_2838 = torch.constant.int 4 - %int1_2839 = torch.constant.int 1 - %2489 = torch.prim.ListConstruct %int4_2838, %int1_2839 : (!torch.int, !torch.int) -> !torch.list - %2490 = torch.aten.repeat %2484, %2489 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_2840 = torch.constant.int 32 - %2491 = torch.aten.mul.Scalar %2480, %int32_2840 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2841 = torch.constant.int 1 - %2492 = torch.aten.add.Tensor %2491, %2488, %int1_2841 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_2842 = torch.constant.int 2 - %2493 = torch.aten.mul.Scalar %2492, %int2_2842 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int2_2832 = 
torch.constant.int 2 + %2918 = torch.aten.transpose.int %2913, %int1_2831, %int2_2832 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %2919 = torch.aten.mul.Tensor %2918, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_2833 = torch.constant.int 3 + %int0_2834 = torch.constant.int 0 + %int64_2835 = torch.constant.int 64 + %int1_2836 = torch.constant.int 1 + %2920 = torch.aten.slice.Tensor %2918, %int3_2833, %int0_2834, %int64_2835, %int1_2836 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_2837 = torch.constant.int 3 + %int64_2838 = torch.constant.int 64 + %int9223372036854775807_2839 = torch.constant.int 9223372036854775807 + %int1_2840 = torch.constant.int 1 + %2921 = torch.aten.slice.Tensor %2918, %int3_2837, %int64_2838, %int9223372036854775807_2839, %int1_2840 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %2922 = torch.aten.neg %2921 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %2923 = torch.prim.ListConstruct %2922, %2920 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_2841 = torch.constant.int -1 + %2924 = torch.aten.cat %2923, %int-1_2841 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %2925 = torch.aten.mul.Tensor %2924, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_2842 = torch.constant.int 1 + %2926 = torch.aten.add.Tensor %2919, %2925, %int1_2842 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_2843 = torch.constant.int 1 - %2494 = torch.aten.add.Tensor %2493, %2490, %int1_2843 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_2844 = torch.constant.int 32 - %2495 = torch.aten.mul.Scalar %2494, %int32_2844 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int2_2844 = torch.constant.int 2 + %2927 = torch.aten.transpose.int %2926, %int1_2843, %int2_2844 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> %int1_2845 = torch.constant.int 1 - %2496 = torch.aten.add.Tensor %2495, %2482, %int1_2845 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_2846 = torch.constant.int 32 - %int2_2847 = torch.constant.int 2 - %int32_2848 = torch.constant.int 32 - %int8_2849 = torch.constant.int 8 - %int128_2850 = torch.constant.int 128 - %2497 = torch.prim.ListConstruct %437, %int32_2846, %int2_2847, %int32_2848, %int8_2849, %int128_2850 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2498 = torch.aten.view %2334, %2497 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2498, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2851 = torch.constant.int 32 - %2499 = torch.aten.mul.int %437, %int32_2851 : !torch.int, !torch.int -> !torch.int - %int2_2852 = torch.constant.int 2 - %2500 = torch.aten.mul.int %2499, %int2_2852 : !torch.int, !torch.int -> !torch.int - %int32_2853 = torch.constant.int 32 - %2501 = torch.aten.mul.int %2500, %int32_2853 : !torch.int, !torch.int 
-> !torch.int - %int8_2854 = torch.constant.int 8 - %int128_2855 = torch.constant.int 128 - %2502 = torch.prim.ListConstruct %2501, %int8_2854, %int128_2855 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2503 = torch.aten.view %2498, %2502 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2503, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %2504 = torch.prim.ListConstruct %2496 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_2856 = torch.constant.bool false - %2505 = torch.aten.index_put %2503, %2504, %2477, %false_2856 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2505, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_2857 = torch.constant.int 32 + %int2_2846 = torch.constant.int 2 + %2928 = torch.aten.transpose.int %2915, %int1_2845, %int2_2846 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %2929 = torch.aten.mul.Tensor %2928, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_2847 = torch.constant.int 3 + %int0_2848 = torch.constant.int 0 + %int64_2849 = torch.constant.int 64 + %int1_2850 = torch.constant.int 1 + %2930 = torch.aten.slice.Tensor %2928, %int3_2847, %int0_2848, %int64_2849, %int1_2850 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_2851 = torch.constant.int 3 + %int64_2852 = torch.constant.int 64 + %int9223372036854775807_2853 = torch.constant.int 9223372036854775807 + %int1_2854 = torch.constant.int 1 + %2931 = torch.aten.slice.Tensor %2928, %int3_2851, %int64_2852, %int9223372036854775807_2853, %int1_2854 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %2932 = torch.aten.neg %2931 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %2933 = torch.prim.ListConstruct %2932, %2930 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_2855 = torch.constant.int -1 + %2934 = torch.aten.cat %2933, %int-1_2855 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %2935 = torch.aten.mul.Tensor %2934, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_2856 = torch.constant.int 1 + %2936 = torch.aten.add.Tensor %2929, %2935, %int1_2856 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_2857 = torch.constant.int 1 %int2_2858 = torch.constant.int 2 + %2937 = torch.aten.transpose.int %2936, %int1_2857, %int2_2858 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> %int32_2859 = torch.constant.int 32 - %int8_2860 = torch.constant.int 8 - %int128_2861 = torch.constant.int 128 - %2506 = torch.prim.ListConstruct %437, %int32_2857, %int2_2858, %int32_2859, %int8_2860, %int128_2861 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2507 = torch.aten.view %2505, %2506 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2507, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2862 
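// Reader annotation: the rotary embedding changed shape here. The `-` side upcast Q/K to f32
// and called the custom kernels @sharktank_rotary_embedding_4_1_32_128_f32 and
// @sharktank_rotary_embedding_4_1_8_128_f32; the `+` side inlines rotate-half directly in f16:
//   q' = q * cos + cat(-q[..., 64:], q[..., :64]) * sin
// with %527 and %531 presumably holding the broadcast cos and sin tables ([4,1,1,128], f16).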
= torch.constant.int 2097152 - %2508 = torch.prim.ListConstruct %437, %int2097152_2862 : (!torch.int, !torch.int) -> !torch.list - %2509 = torch.aten.view %2507, %2508 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2509, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_2863 = torch.constant.int 32 - %int2_2864 = torch.constant.int 2 - %int32_2865 = torch.constant.int 32 - %int8_2866 = torch.constant.int 8 - %int128_2867 = torch.constant.int 128 - %2510 = torch.prim.ListConstruct %437, %int32_2863, %int2_2864, %int32_2865, %int8_2866, %int128_2867 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2511 = torch.aten.view %2509, %2510 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2511, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_2868 = torch.constant.int 8 - %int128_2869 = torch.constant.int 128 - %2512 = torch.prim.ListConstruct %2501, %int8_2868, %int128_2869 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2513 = torch.aten.view %2511, %2512 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2513, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_2870 = torch.constant.int 32 - %2514 = torch.aten.floor_divide.Scalar %arg2, %int32_2870 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_2871 = torch.constant.int 1 - %2515 = torch.aten.unsqueeze %2514, %int1_2871 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2872 = torch.constant.int 1 - %false_2873 = torch.constant.bool false - %2516 = torch.aten.gather %arg3, %int1_2872, %2515, %false_2873 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_2874 = torch.constant.int 32 - %2517 = torch.aten.remainder.Scalar %arg2, %int32_2874 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %2938 = torch.aten.floor_divide.Scalar %arg2, %int32_2859 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_2860 = torch.constant.int 1 + %2939 = torch.aten.unsqueeze %2938, %int1_2860 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_2861 = torch.constant.int 1 + %false_2862 = torch.constant.bool false + %2940 = torch.aten.gather %arg3, %int1_2861, %2939, %false_2862 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_2863 = torch.constant.int 4 + %int1_2864 = torch.constant.int 1 + %int1_2865 = torch.constant.int 1 + %2941 = torch.prim.ListConstruct %int4_2863, %int1_2864, %int1_2865 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2942 = torch.aten.view %2940, %2941 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_2866 = torch.constant.int 32 + %2943 = torch.aten.remainder.Scalar %arg2, %int32_2866 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_2867 = torch.constant.int 4 + %int1_2868 = torch.constant.int 1 + %int1_2869 = torch.constant.int 1 + %2944 = torch.prim.ListConstruct %int4_2867, %int1_2868, %int1_2869 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2945 = torch.aten.view %2943, %2944 : !torch.vtensor<[4],si64>, !torch.list -> 
!torch.vtensor<[4,1,1],si64> + %int8_2870 = torch.constant.int 8 + %none_2871 = torch.constant.none + %none_2872 = torch.constant.none + %cpu_2873 = torch.constant.device "cpu" + %false_2874 = torch.constant.bool false + %2946 = torch.aten.arange %int8_2870, %none_2871, %none_2872, %cpu_2873, %false_2874 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> %int1_2875 = torch.constant.int 1 - %2518 = torch.aten.unsqueeze %2517, %int1_2875 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_2876 = torch.constant.none - %2519 = torch.aten.clone %116, %none_2876 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_2877 = torch.constant.int 0 - %2520 = torch.aten.unsqueeze %2519, %int0_2877 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_2878 = torch.constant.int 4 + %int1_2876 = torch.constant.int 1 + %int8_2877 = torch.constant.int 8 + %2947 = torch.prim.ListConstruct %int1_2875, %int1_2876, %int8_2877 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2948 = torch.aten.view %2946, %2947 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_2878 = torch.constant.none + %2949 = torch.aten.clone %160, %none_2878 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2950 = torch.aten.detach %2949 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2951 = torch.aten.detach %2950 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2952 = torch.aten.detach %2951 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %int1_2879 = torch.constant.int 1 - %2521 = torch.prim.ListConstruct %int4_2878, %int1_2879 : (!torch.int, !torch.int) -> !torch.list %int1_2880 = torch.constant.int 1 %int1_2881 = torch.constant.int 1 - %2522 = torch.prim.ListConstruct %int1_2880, %int1_2881 : (!torch.int, !torch.int) -> !torch.list - %int4_2882 = torch.constant.int 4 - %int0_2883 = torch.constant.int 0 - %cpu_2884 = torch.constant.device "cpu" - %false_2885 = torch.constant.bool false - %2523 = torch.aten.empty_strided %2521, %2522, %int4_2882, %int0_2883, %cpu_2884, %false_2885 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int10_2886 = torch.constant.int 10 - %2524 = torch.aten.fill.Scalar %2523, %int10_2886 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_2887 = torch.constant.int 4 - %int1_2888 = torch.constant.int 1 - %2525 = torch.prim.ListConstruct %int4_2887, %int1_2888 : (!torch.int, !torch.int) -> !torch.list - %2526 = torch.aten.repeat %2520, %2525 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_2889 = torch.constant.int 32 - %2527 = torch.aten.mul.Scalar %2516, %int32_2889 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2890 = torch.constant.int 1 - %2528 = torch.aten.add.Tensor %2527, %2524, %int1_2890 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_2891 = torch.constant.int 2 - %2529 = torch.aten.mul.Scalar %2528, %int2_2891 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_2892 = torch.constant.int 1 - %2530 = torch.aten.add.Tensor %2529, %2526, %int1_2892 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_2893 = torch.constant.int 32 - %2531 = torch.aten.mul.Scalar %2530, %int32_2893 : !torch.vtensor<[4,1],si64>, !torch.int -> 
!torch.vtensor<[4,1],si64> - %int1_2894 = torch.constant.int 1 - %2532 = torch.aten.add.Tensor %2531, %2518, %int1_2894 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %2533 = torch.prim.ListConstruct %2532 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_2895 = torch.constant.bool false - %2534 = torch.aten.index_put %2513, %2533, %2465, %false_2895 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2534, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_2896 = torch.constant.int 32 - %int2_2897 = torch.constant.int 2 + %2953 = torch.prim.ListConstruct %int1_2879, %int1_2880, %int1_2881 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2954 = torch.aten.view %2952, %2953 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_2882 = torch.constant.int 32 + %2955 = torch.aten.mul.Scalar %2942, %int32_2882 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int11 = torch.constant.int 11 + %int1_2883 = torch.constant.int 1 + %2956 = torch.aten.add.Scalar %2955, %int11, %int1_2883 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_2884 = torch.constant.int 2 + %2957 = torch.aten.mul.Scalar %2956, %int2_2884 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_2885 = torch.constant.int 1 + %2958 = torch.aten.add.Tensor %2957, %2954, %int1_2885 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_2886 = torch.constant.int 8 + %2959 = torch.aten.mul.Scalar %2958, %int8_2886 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_2887 = torch.constant.int 1 + %2960 = torch.aten.add.Tensor %2959, %2948, %int1_2887 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_2888 = torch.constant.int 32 + %2961 = torch.aten.mul.Scalar %2960, %int32_2888 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_2889 = torch.constant.int 1 + %2962 = torch.aten.add.Tensor %2961, %2945, %int1_2889 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_2890 = torch.constant.int 5 + %2963 = torch.prims.convert_element_type %2937, %int5_2890 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_2891 = torch.constant.int 32 + %int2_2892 = torch.constant.int 2 + %int8_2893 = torch.constant.int 8 + %int32_2894 = torch.constant.int 32 + %int128_2895 = torch.constant.int 128 + %2964 = torch.prim.ListConstruct %456, %int32_2891, %int2_2892, %int8_2893, %int32_2894, %int128_2895 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2965 = torch.aten.view %2785, %2964 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2965, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_2896 = torch.constant.int 128 + %2966 = torch.prim.ListConstruct %596, %int128_2896 : (!torch.int, !torch.int) -> !torch.list + %2967 = torch.aten.view %2965, %2966 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2967, [%454], 
affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %2968 = torch.prim.ListConstruct %2962 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_2897 = torch.constant.bool false + %2969 = torch.aten.index_put %2967, %2968, %2963, %false_2897 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2969, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> %int32_2898 = torch.constant.int 32 - %int8_2899 = torch.constant.int 8 - %int128_2900 = torch.constant.int 128 - %2535 = torch.prim.ListConstruct %437, %int32_2896, %int2_2897, %int32_2898, %int8_2899, %int128_2900 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2536 = torch.aten.view %2534, %2535 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2536, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_2901 = torch.constant.int 2097152 - %2537 = torch.prim.ListConstruct %437, %int2097152_2901 : (!torch.int, !torch.int) -> !torch.list - %2538 = torch.aten.view %2536, %2537 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2538, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_2902 = torch.constant.int 4 - %2539 = torch.prim.ListConstruct %int4_2902, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_2903 = torch.constant.int 1 - %2540 = torch.prim.ListConstruct %358, %int1_2903 : (!torch.int, !torch.int) -> !torch.list - %int4_2904 = torch.constant.int 4 - %int0_2905 = torch.constant.int 0 - %cpu_2906 = torch.constant.device "cpu" - %false_2907 = torch.constant.bool false - %2541 = torch.aten.empty_strided %2539, %2540, %int4_2904, %int0_2905, %cpu_2906, %false_2907 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2541, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int10_2908 = torch.constant.int 10 - %2542 = torch.aten.fill.Scalar %2541, %int10_2908 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2542, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_2909 = torch.constant.int 32 - %2543 = torch.aten.mul.Scalar %arg3, %int32_2909 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2543, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_2910 = torch.constant.int 1 - %2544 = torch.aten.add.Tensor %2543, %2542, %int1_2910 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2544, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_2911 = torch.constant.int 4 - %2545 = torch.aten.mul.int %int4_2911, %358 : !torch.int, !torch.int -> !torch.int - %2546 = torch.prim.ListConstruct %2545 : (!torch.int) -> !torch.list - %2547 = torch.aten.view %2544, %2546 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2547, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_2912 = torch.constant.int 32 - %int2_2913 = torch.constant.int 2 + %int2_2899 = torch.constant.int 2 + %int8_2900 = torch.constant.int 8 + %int32_2901 = 
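// Reader annotation: this is the paged KV-cache write. The `+` side linearizes a cache slot as
//   (((page_id * 32 + layer) * 2 + kv_partition) * 8 + kv_head) * 32 + block_pos
// with page_id = gather(%arg3, %arg2 / 32) and block_pos = %arg2 % 32; the layer comes from the
// literal (11 here) and the k/v partition from the scalar globals %160 / %161. The cache view
// changes from the `-` side's [?,32,2,32,8,128] (flattened to [?*2048,8,128]) to
// [?,32,2,8,32,128] (flattened to [?*16384,128]); both factor the 2097152-wide page as
// 32 layers x 2 (k,v) x 8 KV heads x 32 positions x 128 dims. The `-`/`+` layer constants
// differ (10 vs 11), which suggests the renumbered hunks align code from adjacent layers.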
torch.constant.int 32 + %int128_2902 = torch.constant.int 128 + %2970 = torch.prim.ListConstruct %456, %int32_2898, %int2_2899, %int8_2900, %int32_2901, %int128_2902 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2971 = torch.aten.view %2969, %2970 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2971, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_2903 = torch.constant.int 2097152 + %2972 = torch.prim.ListConstruct %456, %int2097152_2903 : (!torch.int, !torch.int) -> !torch.list + %2973 = torch.aten.view %2971, %2972 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2973, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_2904 = torch.constant.int 32 + %int2_2905 = torch.constant.int 2 + %int8_2906 = torch.constant.int 8 + %int32_2907 = torch.constant.int 32 + %int128_2908 = torch.constant.int 128 + %2974 = torch.prim.ListConstruct %456, %int32_2904, %int2_2905, %int8_2906, %int32_2907, %int128_2908 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2975 = torch.aten.view %2973, %2974 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2975, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_2909 = torch.constant.int 128 + %2976 = torch.prim.ListConstruct %596, %int128_2909 : (!torch.int, !torch.int) -> !torch.list + %2977 = torch.aten.view %2975, %2976 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2977, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_2910 = torch.constant.none + %2978 = torch.aten.clone %161, %none_2910 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %2979 = torch.aten.detach %2978 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2980 = torch.aten.detach %2979 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %2981 = torch.aten.detach %2980 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_2911 = torch.constant.int 1 + %int1_2912 = torch.constant.int 1 + %int1_2913 = torch.constant.int 1 + %2982 = torch.prim.ListConstruct %int1_2911, %int1_2912, %int1_2913 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2983 = torch.aten.view %2981, %2982 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> %int32_2914 = torch.constant.int 32 - %int8_2915 = torch.constant.int 8 - %int128_2916 = torch.constant.int 128 - %2548 = torch.prim.ListConstruct %437, %int32_2912, %int2_2913, %int32_2914, %int8_2915, %int128_2916 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2549 = torch.aten.view %2538, %2548 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2549, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_2917 = torch.constant.int 32 - %2550 = torch.aten.mul.int %437, %int32_2917 : !torch.int, !torch.int -> !torch.int - %int2_2918 = torch.constant.int 2 - %int32_2919 = torch.constant.int 32 - %int8_2920 = torch.constant.int 8 - %int128_2921 = torch.constant.int 128 - %2551 = torch.prim.ListConstruct %2550, %int2_2918, 
%int32_2919, %int8_2920, %int128_2921 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2552 = torch.aten.view %2549, %2551 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %2552, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_2922 = torch.constant.int 0 - %2553 = torch.aten.index_select %2552, %int0_2922, %2547 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %2553, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_2923 = torch.constant.int 4 - %int2_2924 = torch.constant.int 2 + %2984 = torch.aten.mul.Scalar %2942, %int32_2914 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int11_2915 = torch.constant.int 11 + %int1_2916 = torch.constant.int 1 + %2985 = torch.aten.add.Scalar %2984, %int11_2915, %int1_2916 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_2917 = torch.constant.int 2 + %2986 = torch.aten.mul.Scalar %2985, %int2_2917 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_2918 = torch.constant.int 1 + %2987 = torch.aten.add.Tensor %2986, %2983, %int1_2918 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_2919 = torch.constant.int 8 + %2988 = torch.aten.mul.Scalar %2987, %int8_2919 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_2920 = torch.constant.int 1 + %2989 = torch.aten.add.Tensor %2988, %2948, %int1_2920 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_2921 = torch.constant.int 32 + %2990 = torch.aten.mul.Scalar %2989, %int32_2921 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_2922 = torch.constant.int 1 + %2991 = torch.aten.add.Tensor %2990, %2945, %int1_2922 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_2923 = torch.constant.int 5 + %2992 = torch.prims.convert_element_type %2917, %int5_2923 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %2993 = torch.prim.ListConstruct %2991 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_2924 = torch.constant.bool false + %2994 = torch.aten.index_put %2977, %2993, %2992, %false_2924 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %2994, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> %int32_2925 = torch.constant.int 32 - %int8_2926 = torch.constant.int 8 - %int128_2927 = torch.constant.int 128 - %2554 = torch.prim.ListConstruct %int4_2923, %358, %int2_2924, %int32_2925, %int8_2926, %int128_2927 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2555 = torch.aten.view %2553, %2554 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2555, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_2928 = torch.constant.int 0 - %int0_2929 = torch.constant.int 0 - %int9223372036854775807_2930 = torch.constant.int 9223372036854775807 - %int1_2931 = 
torch.constant.int 1 - %2556 = torch.aten.slice.Tensor %2555, %int0_2928, %int0_2929, %int9223372036854775807_2930, %int1_2931 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2556, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_2932 = torch.constant.int 1 - %int0_2933 = torch.constant.int 0 - %int9223372036854775807_2934 = torch.constant.int 9223372036854775807 - %int1_2935 = torch.constant.int 1 - %2557 = torch.aten.slice.Tensor %2556, %int1_2932, %int0_2933, %int9223372036854775807_2934, %int1_2935 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2557, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_2936 = torch.constant.int 2 - %int0_2937 = torch.constant.int 0 - %2558 = torch.aten.select.int %2557, %int2_2936, %int0_2937 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2558, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_2938 = torch.constant.int 32 - %2559 = torch.aten.mul.int %358, %int32_2938 : !torch.int, !torch.int -> !torch.int - %int2_2939 = torch.constant.int 2 - %int0_2940 = torch.constant.int 0 - %int1_2941 = torch.constant.int 1 - %2560 = torch.aten.slice.Tensor %2558, %int2_2939, %int0_2940, %2559, %int1_2941 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2560, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_2942 = torch.constant.int 0 - %2561 = torch.aten.clone %2560, %int0_2942 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2561, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_2943 = torch.constant.int 1 - %2562 = torch.aten.size.int %2557, %int1_2943 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_2944 = torch.constant.int 32 - %2563 = torch.aten.mul.int %2562, %int32_2944 : !torch.int, !torch.int -> !torch.int - %int4_2945 = torch.constant.int 4 - %int8_2946 = torch.constant.int 8 - %int128_2947 = torch.constant.int 128 - %2564 = torch.prim.ListConstruct %int4_2945, %2563, %int8_2946, %int128_2947 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2565 = torch.aten._unsafe_view %2561, %2564 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2565, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_2948 = torch.constant.int 0 - %int0_2949 = torch.constant.int 0 - %int9223372036854775807_2950 = torch.constant.int 9223372036854775807 - %int1_2951 = torch.constant.int 1 - %2566 = torch.aten.slice.Tensor %2565, %int0_2948, %int0_2949, %int9223372036854775807_2950, %int1_2951 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2566, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_2952 = torch.constant.int 0 - %int0_2953 = torch.constant.int 0 - %int9223372036854775807_2954 = 
torch.constant.int 9223372036854775807 - %int1_2955 = torch.constant.int 1 - %2567 = torch.aten.slice.Tensor %2555, %int0_2952, %int0_2953, %int9223372036854775807_2954, %int1_2955 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2567, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_2956 = torch.constant.int 1 - %int0_2957 = torch.constant.int 0 - %int9223372036854775807_2958 = torch.constant.int 9223372036854775807 - %int1_2959 = torch.constant.int 1 - %2568 = torch.aten.slice.Tensor %2567, %int1_2956, %int0_2957, %int9223372036854775807_2958, %int1_2959 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2568, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_2960 = torch.constant.int 2 - %int1_2961 = torch.constant.int 1 - %2569 = torch.aten.select.int %2568, %int2_2960, %int1_2961 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2569, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_2962 = torch.constant.int 2 - %int0_2963 = torch.constant.int 0 - %int1_2964 = torch.constant.int 1 - %2570 = torch.aten.slice.Tensor %2569, %int2_2962, %int0_2963, %2559, %int1_2964 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2570, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_2965 = torch.constant.int 0 - %2571 = torch.aten.clone %2570, %int0_2965 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2571, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_2966 = torch.constant.int 1 - %2572 = torch.aten.size.int %2568, %int1_2966 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_2967 = torch.constant.int 32 - %2573 = torch.aten.mul.int %2572, %int32_2967 : !torch.int, !torch.int -> !torch.int + %int2_2926 = torch.constant.int 2 + %int8_2927 = torch.constant.int 8 + %int32_2928 = torch.constant.int 32 + %int128_2929 = torch.constant.int 128 + %2995 = torch.prim.ListConstruct %456, %int32_2925, %int2_2926, %int8_2927, %int32_2928, %int128_2929 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %2996 = torch.aten.view %2994, %2995 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %2996, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_2930 = torch.constant.int 2097152 + %2997 = torch.prim.ListConstruct %456, %int2097152_2930 : (!torch.int, !torch.int) -> !torch.list + %2998 = torch.aten.view %2996, %2997 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %2998, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_2931 = torch.constant.none + %2999 = torch.aten.clone %162, %none_2931 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3000 = torch.aten.detach %2999 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + 
%3001 = torch.aten.detach %3000 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3002 = torch.aten.detach %3001 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_2932 = torch.constant.none + %3003 = torch.aten.clone %163, %none_2932 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3004 = torch.aten.detach %3003 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3005 = torch.aten.detach %3004 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3006 = torch.aten.detach %3005 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_2933 = torch.constant.none + %3007 = torch.aten.clone %164, %none_2933 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3008 = torch.aten.detach %3007 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3009 = torch.aten.detach %3008 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3010 = torch.aten.detach %3009 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_2934 = torch.constant.int 32 + %int2_2935 = torch.constant.int 2 + %int8_2936 = torch.constant.int 8 + %int32_2937 = torch.constant.int 32 + %int128_2938 = torch.constant.int 128 + %3011 = torch.prim.ListConstruct %456, %int32_2934, %int2_2935, %int8_2936, %int32_2937, %int128_2938 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3012 = torch.aten.view %2998, %3011 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3012, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %3013 = torch_c.to_builtin_tensor %3012 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %3014 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_2939 = tensor.cast %3014 : tensor<4x?xi64> to tensor + %3015 = torch_c.to_builtin_tensor %3002 : !torch.vtensor<[],si64> -> tensor + %3016 = torch_c.to_builtin_tensor %3006 : !torch.vtensor<[],si64> -> tensor + %3017 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%3013, %cast_2939, %3015, %3016) : (tensor, tensor, tensor, tensor) -> tensor + %cast_2940 = tensor.cast %3017 : tensor to tensor<4x?x8x32x128xf16> + %3018 = torch_c.from_builtin_tensor %cast_2940 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %3018, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %3019 = torch_c.to_builtin_tensor %3012 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %3020 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_2941 = tensor.cast %3020 : tensor<4x?xi64> to tensor + %3021 = torch_c.to_builtin_tensor %3002 : !torch.vtensor<[],si64> -> tensor + %3022 = torch_c.to_builtin_tensor %3010 : !torch.vtensor<[],si64> -> tensor + %3023 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%3019, %cast_2941, %3021, %3022) : (tensor, tensor, tensor, tensor) -> tensor + %cast_2942 = tensor.cast %3023 : tensor to tensor<4x?x8x32x128xf16> + %3024 = torch_c.from_builtin_tensor %cast_2942 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + 
torch.bind_symbolic_shape %3024, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_2943 = torch.constant.int 2 + %int3_2944 = torch.constant.int 3 + %3025 = torch.aten.transpose.int %3018, %int2_2943, %int3_2944 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3025, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_2945 = torch.constant.int 0 + %3026 = torch.aten.clone %3025, %int0_2945 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3026, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_2946 = torch.constant.int 4 + %int8_2947 = torch.constant.int 8 + %int128_2948 = torch.constant.int 128 + %3027 = torch.prim.ListConstruct %int4_2946, %457, %int8_2947, %int128_2948 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3028 = torch.aten._unsafe_view %3026, %3027 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3028, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_2949 = torch.constant.int 2 + %int3_2950 = torch.constant.int 3 + %3029 = torch.aten.transpose.int %3024, %int2_2949, %int3_2950 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3029, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_2951 = torch.constant.int 0 + %3030 = torch.aten.clone %3029, %int0_2951 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3030, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_2952 = torch.constant.int 4 + %int8_2953 = torch.constant.int 8 + %int128_2954 = torch.constant.int 128 + %3031 = torch.prim.ListConstruct %int4_2952, %457, %int8_2953, %int128_2954 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3032 = torch.aten._unsafe_view %3030, %3031 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3032, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_2955 = torch.constant.int -2 + %3033 = torch.aten.unsqueeze %3028, %int-2_2955 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %3033, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_2956 = torch.constant.int 4 + %int8_2957 = torch.constant.int 8 + %int4_2958 = torch.constant.int 4 + %int128_2959 = torch.constant.int 128 + %3034 = torch.prim.ListConstruct %int4_2956, %457, %int8_2957, %int4_2958, %int128_2959 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2960 = torch.constant.bool false + %3035 = torch.aten.expand %3033, %3034, %false_2960 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3035, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_2961 = torch.constant.int 0 + %3036 = torch.aten.clone %3035, %int0_2961 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + 
torch.bind_symbolic_shape %3036, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_2962 = torch.constant.int 4 + %int32_2963 = torch.constant.int 32 + %int128_2964 = torch.constant.int 128 + %3037 = torch.prim.ListConstruct %int4_2962, %457, %int32_2963, %int128_2964 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3038 = torch.aten._unsafe_view %3036, %3037 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3038, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_2965 = torch.constant.int -2 + %3039 = torch.aten.unsqueeze %3032, %int-2_2965 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %3039, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_2966 = torch.constant.int 4 + %int8_2967 = torch.constant.int 8 %int4_2968 = torch.constant.int 4 - %int8_2969 = torch.constant.int 8 - %int128_2970 = torch.constant.int 128 - %2574 = torch.prim.ListConstruct %int4_2968, %2573, %int8_2969, %int128_2970 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2575 = torch.aten._unsafe_view %2571, %2574 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2575, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int128_2969 = torch.constant.int 128 + %3040 = torch.prim.ListConstruct %int4_2966, %457, %int8_2967, %int4_2968, %int128_2969 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_2970 = torch.constant.bool false + %3041 = torch.aten.expand %3039, %3040, %false_2970 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3041, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int0_2971 = torch.constant.int 0 - %int0_2972 = torch.constant.int 0 - %int9223372036854775807_2973 = torch.constant.int 9223372036854775807 - %int1_2974 = torch.constant.int 1 - %2576 = torch.aten.slice.Tensor %2575, %int0_2971, %int0_2972, %int9223372036854775807_2973, %int1_2974 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2576, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_2975 = torch.constant.int -2 - %2577 = torch.aten.unsqueeze %2566, %int-2_2975 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2577, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_2976 = torch.constant.int 1 - %2578 = torch.aten.size.int %2565, %int1_2976 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_2977 = torch.constant.int 4 - %int8_2978 = torch.constant.int 8 - %int4_2979 = torch.constant.int 4 - %int128_2980 = torch.constant.int 128 - %2579 = torch.prim.ListConstruct %int4_2977, %2578, %int8_2978, %int4_2979, %int128_2980 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2981 = torch.constant.bool false - %2580 = torch.aten.expand %2577, %2579, %false_2981 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2580, [%356], 
affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2982 = torch.constant.int 0 - %2581 = torch.aten.clone %2580, %int0_2982 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2581, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_2983 = torch.constant.int 4 - %int32_2984 = torch.constant.int 32 - %int128_2985 = torch.constant.int 128 - %2582 = torch.prim.ListConstruct %int4_2983, %2578, %int32_2984, %int128_2985 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2583 = torch.aten._unsafe_view %2581, %2582 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2583, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_2986 = torch.constant.int -2 - %2584 = torch.aten.unsqueeze %2576, %int-2_2986 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2584, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %3042 = torch.aten.clone %3041, %int0_2971 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3042, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_2972 = torch.constant.int 4 + %int32_2973 = torch.constant.int 32 + %int128_2974 = torch.constant.int 128 + %3043 = torch.prim.ListConstruct %int4_2972, %457, %int32_2973, %int128_2974 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3044 = torch.aten._unsafe_view %3042, %3043 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3044, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_2975 = torch.constant.int 1 + %int2_2976 = torch.constant.int 2 + %3045 = torch.aten.transpose.int %2927, %int1_2975, %int2_2976 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_2977 = torch.constant.int 1 + %int2_2978 = torch.constant.int 2 + %3046 = torch.aten.transpose.int %3038, %int1_2977, %int2_2978 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3046, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_2979 = torch.constant.int 1 + %int2_2980 = torch.constant.int 2 + %3047 = torch.aten.transpose.int %3044, %int1_2979, %int2_2980 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3047, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_2981 = torch.constant.float 0.000000e+00 + %false_2982 = torch.constant.bool false + %none_2983 = torch.constant.none + %3048:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3045, %3046, %3047, %float0.000000e00_2981, %false_2982, %470, %none_2983) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_2984 = torch.constant.int 1 + %int2_2985 = torch.constant.int 2 + %3049 = torch.aten.transpose.int %3048#0, 
%int1_2984, %int2_2985 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_2986 = torch.constant.int 4 %int1_2987 = torch.constant.int 1 - %2585 = torch.aten.size.int %2575, %int1_2987 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_2988 = torch.constant.int 4 - %int8_2989 = torch.constant.int 8 - %int4_2990 = torch.constant.int 4 - %int128_2991 = torch.constant.int 128 - %2586 = torch.prim.ListConstruct %int4_2988, %2585, %int8_2989, %int4_2990, %int128_2991 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_2992 = torch.constant.bool false - %2587 = torch.aten.expand %2584, %2586, %false_2992 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2587, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_2993 = torch.constant.int 0 - %2588 = torch.aten.clone %2587, %int0_2993 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2588, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4096_2988 = torch.constant.int 4096 + %3050 = torch.prim.ListConstruct %int4_2986, %int1_2987, %int4096_2988 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3051 = torch.aten.view %3049, %3050 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_2989 = torch.constant.int -2 + %int-1_2990 = torch.constant.int -1 + %3052 = torch.aten.transpose.int %165, %int-2_2989, %int-1_2990 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_2991 = torch.constant.int 5 + %3053 = torch.prims.convert_element_type %3052, %int5_2991 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_2992 = torch.constant.int 4 + %int4096_2993 = torch.constant.int 4096 + %3054 = torch.prim.ListConstruct %int4_2992, %int4096_2993 : (!torch.int, !torch.int) -> !torch.list + %3055 = torch.aten.view %3051, %3054 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3056 = torch.aten.mm %3055, %3053 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_2994 = torch.constant.int 4 - %int32_2995 = torch.constant.int 32 - %int128_2996 = torch.constant.int 128 - %2589 = torch.prim.ListConstruct %int4_2994, %2585, %int32_2995, %int128_2996 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2590 = torch.aten._unsafe_view %2588, %2589 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2590, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_2995 = torch.constant.int 1 + %int4096_2996 = torch.constant.int 4096 + %3057 = torch.prim.ListConstruct %int4_2994, %int1_2995, %int4096_2996 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3058 = torch.aten.view %3056, %3057 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_2997 = torch.constant.int 1 - %int2_2998 = torch.constant.int 2 - %2591 = torch.aten.transpose.int %2471, %int1_2997, %int2_2998 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_2999 = torch.constant.int 1 - %int2_3000 = torch.constant.int 2 - %2592 = torch.aten.transpose.int %2583, %int1_2999, 
%int2_3000 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2592, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_3001 = torch.constant.int 1 - %int2_3002 = torch.constant.int 2 - %2593 = torch.aten.transpose.int %2590, %int1_3001, %int2_3002 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %2593, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_3003 = torch.constant.float 0.000000e+00 - %false_3004 = torch.constant.bool false - %none_3005 = torch.constant.none - %2594:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2591, %2592, %2593, %float0.000000e00_3003, %false_3004, %368, %none_3005) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_3006 = torch.constant.int 1 - %int2_3007 = torch.constant.int 2 - %2595 = torch.aten.transpose.int %2594#0, %int1_3006, %int2_3007 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_3008 = torch.constant.int 4 - %int1_3009 = torch.constant.int 1 - %int4096_3010 = torch.constant.int 4096 - %2596 = torch.prim.ListConstruct %int4_3008, %int1_3009, %int4096_3010 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2597 = torch.aten.view %2595, %2596 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_3011 = torch.constant.int -2 - %int-1_3012 = torch.constant.int -1 - %2598 = torch.aten.transpose.int %117, %int-2_3011, %int-1_3012 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3013 = torch.constant.int 4 - %int4096_3014 = torch.constant.int 4096 - %2599 = torch.prim.ListConstruct %int4_3013, %int4096_3014 : (!torch.int, !torch.int) -> !torch.list - %2600 = torch.aten.view %2597, %2599 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2601 = torch.aten.mm %2600, %2598 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_3015 = torch.constant.int 4 - %int1_3016 = torch.constant.int 1 - %int4096_3017 = torch.constant.int 4096 - %2602 = torch.prim.ListConstruct %int4_3015, %int1_3016, %int4096_3017 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2603 = torch.aten.view %2601, %2602 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_3018 = torch.constant.int 1 - %2604 = torch.aten.add.Tensor %2431, %2603, %int1_3018 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_3019 = torch.constant.int 6 - %2605 = torch.prims.convert_element_type %2604, %int6_3019 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_3020 = torch.constant.int 2 - %2606 = torch.aten.pow.Tensor_Scalar %2605, %int2_3020 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_3021 = torch.constant.int -1 - %2607 = torch.prim.ListConstruct %int-1_3021 : (!torch.int) -> !torch.list - %true_3022 = torch.constant.bool true - %none_3023 = torch.constant.none - %2608 = torch.aten.mean.dim %2606, %2607, %true_3022, %none_3023 : 
!torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_3024 = torch.constant.float 9.9999997473787516E-6 - %int1_3025 = torch.constant.int 1 - %2609 = torch.aten.add.Scalar %2608, %float9.999990e-06_3024, %int1_3025 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %2610 = torch.aten.rsqrt %2609 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %2611 = torch.aten.mul.Tensor %2605, %2610 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_3026 = torch.constant.int 5 - %2612 = torch.prims.convert_element_type %2611, %int5_3026 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %2613 = torch.aten.mul.Tensor %118, %2612 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_3027 = torch.constant.int 5 - %2614 = torch.prims.convert_element_type %2613, %int5_3027 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_3028 = torch.constant.int -2 - %int-1_3029 = torch.constant.int -1 - %2615 = torch.aten.transpose.int %119, %int-2_3028, %int-1_3029 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3030 = torch.constant.int 4 - %int4096_3031 = torch.constant.int 4096 - %2616 = torch.prim.ListConstruct %int4_3030, %int4096_3031 : (!torch.int, !torch.int) -> !torch.list - %2617 = torch.aten.view %2614, %2616 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2618 = torch.aten.mm %2617, %2615 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_3032 = torch.constant.int 4 - %int1_3033 = torch.constant.int 1 - %int14336_3034 = torch.constant.int 14336 - %2619 = torch.prim.ListConstruct %int4_3032, %int1_3033, %int14336_3034 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2620 = torch.aten.view %2618, %2619 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %2621 = torch.aten.silu %2620 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_3035 = torch.constant.int -2 - %int-1_3036 = torch.constant.int -1 - %2622 = torch.aten.transpose.int %120, %int-2_3035, %int-1_3036 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3037 = torch.constant.int 4 - %int4096_3038 = torch.constant.int 4096 - %2623 = torch.prim.ListConstruct %int4_3037, %int4096_3038 : (!torch.int, !torch.int) -> !torch.list - %2624 = torch.aten.view %2614, %2623 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2625 = torch.aten.mm %2624, %2622 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_3039 = torch.constant.int 4 - %int1_3040 = torch.constant.int 1 - %int14336_3041 = torch.constant.int 14336 - %2626 = torch.prim.ListConstruct %int4_3039, %int1_3040, %int14336_3041 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2627 = torch.aten.view %2625, %2626 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %2628 = torch.aten.mul.Tensor %2621, %2627 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_3042 = torch.constant.int -2 - %int-1_3043 = torch.constant.int -1 - %2629 = torch.aten.transpose.int %121, %int-2_3042, %int-1_3043 : 
!torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %3059 = torch.aten.add.Tensor %2880, %3058, %int1_2997 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_2998 = torch.constant.int 6 + %3060 = torch.prims.convert_element_type %3059, %int6_2998 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_2999 = torch.constant.int 2 + %3061 = torch.aten.pow.Tensor_Scalar %3060, %int2_2999 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_3000 = torch.constant.int -1 + %3062 = torch.prim.ListConstruct %int-1_3000 : (!torch.int) -> !torch.list + %true_3001 = torch.constant.bool true + %none_3002 = torch.constant.none + %3063 = torch.aten.mean.dim %3061, %3062, %true_3001, %none_3002 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_3003 = torch.constant.float 9.9999997473787516E-6 + %int1_3004 = torch.constant.int 1 + %3064 = torch.aten.add.Scalar %3063, %float9.999990e-06_3003, %int1_3004 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %3065 = torch.aten.rsqrt %3064 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %3066 = torch.aten.mul.Tensor %3060, %3065 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_3005 = torch.constant.int 5 + %3067 = torch.prims.convert_element_type %3066, %int5_3005 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %3068 = torch.aten.mul.Tensor %166, %3067 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_3006 = torch.constant.int 5 + %3069 = torch.prims.convert_element_type %3068, %int5_3006 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_3007 = torch.constant.int -2 + %int-1_3008 = torch.constant.int -1 + %3070 = torch.aten.transpose.int %167, %int-2_3007, %int-1_3008 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_3009 = torch.constant.int 5 + %3071 = torch.prims.convert_element_type %3070, %int5_3009 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_3010 = torch.constant.int 4 + %int4096_3011 = torch.constant.int 4096 + %3072 = torch.prim.ListConstruct %int4_3010, %int4096_3011 : (!torch.int, !torch.int) -> !torch.list + %3073 = torch.aten.view %3069, %3072 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3074 = torch.aten.mm %3073, %3071 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_3012 = torch.constant.int 4 + %int1_3013 = torch.constant.int 1 + %int14336_3014 = torch.constant.int 14336 + %3075 = torch.prim.ListConstruct %int4_3012, %int1_3013, %int14336_3014 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3076 = torch.aten.view %3074, %3075 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %3077 = torch.aten.silu %3076 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_3015 = torch.constant.int -2 + %int-1_3016 = torch.constant.int -1 + %3078 = torch.aten.transpose.int %168, %int-2_3015, %int-1_3016 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_3017 = torch.constant.int 5 + 
%3079 = torch.prims.convert_element_type %3078, %int5_3017 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_3018 = torch.constant.int 4 + %int4096_3019 = torch.constant.int 4096 + %3080 = torch.prim.ListConstruct %int4_3018, %int4096_3019 : (!torch.int, !torch.int) -> !torch.list + %3081 = torch.aten.view %3069, %3080 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3082 = torch.aten.mm %3081, %3079 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_3020 = torch.constant.int 4 + %int1_3021 = torch.constant.int 1 + %int14336_3022 = torch.constant.int 14336 + %3083 = torch.prim.ListConstruct %int4_3020, %int1_3021, %int14336_3022 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3084 = torch.aten.view %3082, %3083 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %3085 = torch.aten.mul.Tensor %3077, %3084 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_3023 = torch.constant.int -2 + %int-1_3024 = torch.constant.int -1 + %3086 = torch.aten.transpose.int %169, %int-2_3023, %int-1_3024 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_3025 = torch.constant.int 5 + %3087 = torch.prims.convert_element_type %3086, %int5_3025 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_3026 = torch.constant.int 4 + %int14336_3027 = torch.constant.int 14336 + %3088 = torch.prim.ListConstruct %int4_3026, %int14336_3027 : (!torch.int, !torch.int) -> !torch.list + %3089 = torch.aten.view %3085, %3088 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %3090 = torch.aten.mm %3089, %3087 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_3028 = torch.constant.int 4 + %int1_3029 = torch.constant.int 1 + %int4096_3030 = torch.constant.int 4096 + %3091 = torch.prim.ListConstruct %int4_3028, %int1_3029, %int4096_3030 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3092 = torch.aten.view %3090, %3091 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_3031 = torch.constant.int 1 + %3093 = torch.aten.add.Tensor %3059, %3092, %int1_3031 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_3032 = torch.constant.int 6 + %3094 = torch.prims.convert_element_type %3093, %int6_3032 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_3033 = torch.constant.int 2 + %3095 = torch.aten.pow.Tensor_Scalar %3094, %int2_3033 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_3034 = torch.constant.int -1 + %3096 = torch.prim.ListConstruct %int-1_3034 : (!torch.int) -> !torch.list + %true_3035 = torch.constant.bool true + %none_3036 = torch.constant.none + %3097 = torch.aten.mean.dim %3095, %3096, %true_3035, %none_3036 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_3037 = torch.constant.float 9.9999997473787516E-6 + %int1_3038 = torch.constant.int 1 + %3098 = torch.aten.add.Scalar %3097, %float9.999990e-06_3037, %int1_3038 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %3099 = torch.aten.rsqrt %3098 : !torch.vtensor<[4,1,1],f32> -> 
!torch.vtensor<[4,1,1],f32> + %3100 = torch.aten.mul.Tensor %3094, %3099 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_3039 = torch.constant.int 5 + %3101 = torch.prims.convert_element_type %3100, %int5_3039 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %3102 = torch.aten.mul.Tensor %170, %3101 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_3040 = torch.constant.int 5 + %3103 = torch.prims.convert_element_type %3102, %int5_3040 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_3041 = torch.constant.int -2 + %int-1_3042 = torch.constant.int -1 + %3104 = torch.aten.transpose.int %171, %int-2_3041, %int-1_3042 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_3043 = torch.constant.int 5 + %3105 = torch.prims.convert_element_type %3104, %int5_3043 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> %int4_3044 = torch.constant.int 4 - %int14336_3045 = torch.constant.int 14336 - %2630 = torch.prim.ListConstruct %int4_3044, %int14336_3045 : (!torch.int, !torch.int) -> !torch.list - %2631 = torch.aten.view %2628, %2630 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %2632 = torch.aten.mm %2631, %2629 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4096_3045 = torch.constant.int 4096 + %3106 = torch.prim.ListConstruct %int4_3044, %int4096_3045 : (!torch.int, !torch.int) -> !torch.list + %3107 = torch.aten.view %3103, %3106 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3108 = torch.aten.mm %3107, %3105 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_3046 = torch.constant.int 4 %int1_3047 = torch.constant.int 1 %int4096_3048 = torch.constant.int 4096 - %2633 = torch.prim.ListConstruct %int4_3046, %int1_3047, %int4096_3048 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2634 = torch.aten.view %2632, %2633 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_3049 = torch.constant.int 1 - %2635 = torch.aten.add.Tensor %2604, %2634, %int1_3049 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_3050 = torch.constant.int 6 - %2636 = torch.prims.convert_element_type %2635, %int6_3050 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_3051 = torch.constant.int 2 - %2637 = torch.aten.pow.Tensor_Scalar %2636, %int2_3051 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_3052 = torch.constant.int -1 - %2638 = torch.prim.ListConstruct %int-1_3052 : (!torch.int) -> !torch.list - %true_3053 = torch.constant.bool true - %none_3054 = torch.constant.none - %2639 = torch.aten.mean.dim %2637, %2638, %true_3053, %none_3054 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_3055 = torch.constant.float 9.9999997473787516E-6 - %int1_3056 = torch.constant.int 1 - %2640 = torch.aten.add.Scalar %2639, %float9.999990e-06_3055, %int1_3056 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %2641 = torch.aten.rsqrt %2640 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %2642 = 
torch.aten.mul.Tensor %2636, %2641 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_3057 = torch.constant.int 5 - %2643 = torch.prims.convert_element_type %2642, %int5_3057 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %2644 = torch.aten.mul.Tensor %122, %2643 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_3058 = torch.constant.int 5 - %2645 = torch.prims.convert_element_type %2644, %int5_3058 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_3059 = torch.constant.int -2 - %int-1_3060 = torch.constant.int -1 - %2646 = torch.aten.transpose.int %123, %int-2_3059, %int-1_3060 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3061 = torch.constant.int 4 - %int4096_3062 = torch.constant.int 4096 - %2647 = torch.prim.ListConstruct %int4_3061, %int4096_3062 : (!torch.int, !torch.int) -> !torch.list - %2648 = torch.aten.view %2645, %2647 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2649 = torch.aten.mm %2648, %2646 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_3063 = torch.constant.int 4 - %int1_3064 = torch.constant.int 1 - %int4096_3065 = torch.constant.int 4096 - %2650 = torch.prim.ListConstruct %int4_3063, %int1_3064, %int4096_3065 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2651 = torch.aten.view %2649, %2650 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_3066 = torch.constant.int -2 - %int-1_3067 = torch.constant.int -1 - %2652 = torch.aten.transpose.int %124, %int-2_3066, %int-1_3067 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3068 = torch.constant.int 4 - %int4096_3069 = torch.constant.int 4096 - %2653 = torch.prim.ListConstruct %int4_3068, %int4096_3069 : (!torch.int, !torch.int) -> !torch.list - %2654 = torch.aten.view %2645, %2653 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2655 = torch.aten.mm %2654, %2652 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_3070 = torch.constant.int 4 - %int1_3071 = torch.constant.int 1 - %int1024_3072 = torch.constant.int 1024 - %2656 = torch.prim.ListConstruct %int4_3070, %int1_3071, %int1024_3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2657 = torch.aten.view %2655, %2656 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_3073 = torch.constant.int -2 - %int-1_3074 = torch.constant.int -1 - %2658 = torch.aten.transpose.int %125, %int-2_3073, %int-1_3074 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3075 = torch.constant.int 4 - %int4096_3076 = torch.constant.int 4096 - %2659 = torch.prim.ListConstruct %int4_3075, %int4096_3076 : (!torch.int, !torch.int) -> !torch.list - %2660 = torch.aten.view %2645, %2659 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %2661 = torch.aten.mm %2660, %2658 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_3077 = torch.constant.int 4 - %int1_3078 = torch.constant.int 1 - %int1024_3079 = torch.constant.int 1024 - %2662 = torch.prim.ListConstruct %int4_3077, %int1_3078, %int1024_3079 : (!torch.int, !torch.int, 
!torch.int) -> !torch.list - %2663 = torch.aten.view %2661, %2662 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_3080 = torch.constant.int 4 - %int1_3081 = torch.constant.int 1 - %int32_3082 = torch.constant.int 32 - %int128_3083 = torch.constant.int 128 - %2664 = torch.prim.ListConstruct %int4_3080, %int1_3081, %int32_3082, %int128_3083 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2665 = torch.aten.view %2651, %2664 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_3084 = torch.constant.int 4 - %int1_3085 = torch.constant.int 1 - %int8_3086 = torch.constant.int 8 - %int128_3087 = torch.constant.int 128 - %2666 = torch.prim.ListConstruct %int4_3084, %int1_3085, %int8_3086, %int128_3087 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2667 = torch.aten.view %2657, %2666 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_3088 = torch.constant.int 4 + %3109 = torch.prim.ListConstruct %int4_3046, %int1_3047, %int4096_3048 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3110 = torch.aten.view %3108, %3109 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_3049 = torch.constant.int -2 + %int-1_3050 = torch.constant.int -1 + %3111 = torch.aten.transpose.int %172, %int-2_3049, %int-1_3050 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_3051 = torch.constant.int 5 + %3112 = torch.prims.convert_element_type %3111, %int5_3051 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_3052 = torch.constant.int 4 + %int4096_3053 = torch.constant.int 4096 + %3113 = torch.prim.ListConstruct %int4_3052, %int4096_3053 : (!torch.int, !torch.int) -> !torch.list + %3114 = torch.aten.view %3103, %3113 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3115 = torch.aten.mm %3114, %3112 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_3054 = torch.constant.int 4 + %int1_3055 = torch.constant.int 1 + %int1024_3056 = torch.constant.int 1024 + %3116 = torch.prim.ListConstruct %int4_3054, %int1_3055, %int1024_3056 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3117 = torch.aten.view %3115, %3116 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_3057 = torch.constant.int -2 + %int-1_3058 = torch.constant.int -1 + %3118 = torch.aten.transpose.int %173, %int-2_3057, %int-1_3058 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_3059 = torch.constant.int 5 + %3119 = torch.prims.convert_element_type %3118, %int5_3059 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_3060 = torch.constant.int 4 + %int4096_3061 = torch.constant.int 4096 + %3120 = torch.prim.ListConstruct %int4_3060, %int4096_3061 : (!torch.int, !torch.int) -> !torch.list + %3121 = torch.aten.view %3103, %3120 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3122 = torch.aten.mm %3121, %3119 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_3062 = torch.constant.int 4 + %int1_3063 = torch.constant.int 1 + %int1024_3064 = torch.constant.int 1024 + %3123 = torch.prim.ListConstruct %int4_3062, %int1_3063, %int1024_3064 : (!torch.int, !torch.int, !torch.int) -> !torch.list + 
%3124 = torch.aten.view %3122, %3123 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_3065 = torch.constant.int 4 + %int1_3066 = torch.constant.int 1 + %int32_3067 = torch.constant.int 32 + %int128_3068 = torch.constant.int 128 + %3125 = torch.prim.ListConstruct %int4_3065, %int1_3066, %int32_3067, %int128_3068 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3126 = torch.aten.view %3110, %3125 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_3069 = torch.constant.int 4 + %int1_3070 = torch.constant.int 1 + %int8_3071 = torch.constant.int 8 + %int128_3072 = torch.constant.int 128 + %3127 = torch.prim.ListConstruct %int4_3069, %int1_3070, %int8_3071, %int128_3072 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3128 = torch.aten.view %3117, %3127 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_3073 = torch.constant.int 4 + %int1_3074 = torch.constant.int 1 + %int8_3075 = torch.constant.int 8 + %int128_3076 = torch.constant.int 128 + %3129 = torch.prim.ListConstruct %int4_3073, %int1_3074, %int8_3075, %int128_3076 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3130 = torch.aten.view %3124, %3129 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_3077 = torch.constant.int 1 + %int2_3078 = torch.constant.int 2 + %3131 = torch.aten.transpose.int %3126, %int1_3077, %int2_3078 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %3132 = torch.aten.mul.Tensor %3131, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_3079 = torch.constant.int 3 + %int0_3080 = torch.constant.int 0 + %int64_3081 = torch.constant.int 64 + %int1_3082 = torch.constant.int 1 + %3133 = torch.aten.slice.Tensor %3131, %int3_3079, %int0_3080, %int64_3081, %int1_3082 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_3083 = torch.constant.int 3 + %int64_3084 = torch.constant.int 64 + %int9223372036854775807_3085 = torch.constant.int 9223372036854775807 + %int1_3086 = torch.constant.int 1 + %3134 = torch.aten.slice.Tensor %3131, %int3_3083, %int64_3084, %int9223372036854775807_3085, %int1_3086 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %3135 = torch.aten.neg %3134 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %3136 = torch.prim.ListConstruct %3135, %3133 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_3087 = torch.constant.int -1 + %3137 = torch.aten.cat %3136, %int-1_3087 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %3138 = torch.aten.mul.Tensor %3137, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_3088 = torch.constant.int 1 + %3139 = torch.aten.add.Tensor %3132, %3138, %int1_3088 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_3089 = torch.constant.int 1 - %int8_3090 = torch.constant.int 8 - %int128_3091 = torch.constant.int 128 - %2668 = torch.prim.ListConstruct %int4_3088, %int1_3089, %int8_3090, %int128_3091 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2669 = torch.aten.view %2663, %2668 : 
!torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_3092 = torch.constant.int 6 - %2670 = torch.prims.convert_element_type %2665, %int6_3092 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %2671 = torch_c.to_builtin_tensor %2670 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %2672 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %2673 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%2671, %2672) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %2674 = torch_c.from_builtin_tensor %2673 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_3093 = torch.constant.int 5 - %2675 = torch.prims.convert_element_type %2674, %int5_3093 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_3094 = torch.constant.int 6 - %2676 = torch.prims.convert_element_type %2667, %int6_3094 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %2677 = torch_c.to_builtin_tensor %2676 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %2678 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %2679 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%2677, %2678) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %2680 = torch_c.from_builtin_tensor %2679 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_3095 = torch.constant.int 5 - %2681 = torch.prims.convert_element_type %2680, %int5_3095 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_3096 = torch.constant.int 32 - %2682 = torch.aten.floor_divide.Scalar %arg2, %int32_3096 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_3097 = torch.constant.int 1 - %2683 = torch.aten.unsqueeze %2682, %int1_3097 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3098 = torch.constant.int 1 - %false_3099 = torch.constant.bool false - %2684 = torch.aten.gather %arg3, %int1_3098, %2683, %false_3099 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_3100 = torch.constant.int 32 - %2685 = torch.aten.remainder.Scalar %arg2, %int32_3100 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_3101 = torch.constant.int 1 - %2686 = torch.aten.unsqueeze %2685, %int1_3101 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_3102 = torch.constant.none - %2687 = torch.aten.clone %126, %none_3102 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_3103 = torch.constant.int 0 - %2688 = torch.aten.unsqueeze %2687, %int0_3103 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_3104 = torch.constant.int 4 - %int1_3105 = torch.constant.int 1 - %2689 = torch.prim.ListConstruct %int4_3104, %int1_3105 : (!torch.int, !torch.int) -> !torch.list + %int2_3090 = torch.constant.int 2 + %3140 = torch.aten.transpose.int %3139, %int1_3089, %int2_3090 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_3091 = torch.constant.int 1 + %int2_3092 = torch.constant.int 2 + %3141 = torch.aten.transpose.int %3128, %int1_3091, %int2_3092 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %3142 = torch.aten.mul.Tensor %3141, %527 : 
!torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_3093 = torch.constant.int 3 + %int0_3094 = torch.constant.int 0 + %int64_3095 = torch.constant.int 64 + %int1_3096 = torch.constant.int 1 + %3143 = torch.aten.slice.Tensor %3141, %int3_3093, %int0_3094, %int64_3095, %int1_3096 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_3097 = torch.constant.int 3 + %int64_3098 = torch.constant.int 64 + %int9223372036854775807_3099 = torch.constant.int 9223372036854775807 + %int1_3100 = torch.constant.int 1 + %3144 = torch.aten.slice.Tensor %3141, %int3_3097, %int64_3098, %int9223372036854775807_3099, %int1_3100 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %3145 = torch.aten.neg %3144 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %3146 = torch.prim.ListConstruct %3145, %3143 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_3101 = torch.constant.int -1 + %3147 = torch.aten.cat %3146, %int-1_3101 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %3148 = torch.aten.mul.Tensor %3147, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_3102 = torch.constant.int 1 + %3149 = torch.aten.add.Tensor %3142, %3148, %int1_3102 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_3103 = torch.constant.int 1 + %int2_3104 = torch.constant.int 2 + %3150 = torch.aten.transpose.int %3149, %int1_3103, %int2_3104 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_3105 = torch.constant.int 32 + %3151 = torch.aten.floor_divide.Scalar %arg2, %int32_3105 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> %int1_3106 = torch.constant.int 1 + %3152 = torch.aten.unsqueeze %3151, %int1_3106 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> %int1_3107 = torch.constant.int 1 - %2690 = torch.prim.ListConstruct %int1_3106, %int1_3107 : (!torch.int, !torch.int) -> !torch.list - %int4_3108 = torch.constant.int 4 - %int0_3109 = torch.constant.int 0 - %cpu_3110 = torch.constant.device "cpu" - %false_3111 = torch.constant.bool false - %2691 = torch.aten.empty_strided %2689, %2690, %int4_3108, %int0_3109, %cpu_3110, %false_3111 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int11 = torch.constant.int 11 - %2692 = torch.aten.fill.Scalar %2691, %int11 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_3112 = torch.constant.int 4 - %int1_3113 = torch.constant.int 1 - %2693 = torch.prim.ListConstruct %int4_3112, %int1_3113 : (!torch.int, !torch.int) -> !torch.list - %2694 = torch.aten.repeat %2688, %2693 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_3114 = torch.constant.int 32 - %2695 = torch.aten.mul.Scalar %2684, %int32_3114 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %false_3108 = torch.constant.bool false + %3153 = torch.aten.gather %arg3, %int1_3107, %3152, %false_3108 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_3109 = torch.constant.int 4 + %int1_3110 = torch.constant.int 1 + %int1_3111 = torch.constant.int 1 + %3154 = 
torch.prim.ListConstruct %int4_3109, %int1_3110, %int1_3111 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3155 = torch.aten.view %3153, %3154 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_3112 = torch.constant.int 32 + %3156 = torch.aten.remainder.Scalar %arg2, %int32_3112 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_3113 = torch.constant.int 4 + %int1_3114 = torch.constant.int 1 %int1_3115 = torch.constant.int 1 - %2696 = torch.aten.add.Tensor %2695, %2692, %int1_3115 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_3116 = torch.constant.int 2 - %2697 = torch.aten.mul.Scalar %2696, %int2_3116 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3117 = torch.constant.int 1 - %2698 = torch.aten.add.Tensor %2697, %2694, %int1_3117 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_3118 = torch.constant.int 32 - %2699 = torch.aten.mul.Scalar %2698, %int32_3118 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3119 = torch.constant.int 1 - %2700 = torch.aten.add.Tensor %2699, %2686, %int1_3119 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_3120 = torch.constant.int 32 - %int2_3121 = torch.constant.int 2 - %int32_3122 = torch.constant.int 32 + %3157 = torch.prim.ListConstruct %int4_3113, %int1_3114, %int1_3115 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3158 = torch.aten.view %3156, %3157 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_3116 = torch.constant.int 8 + %none_3117 = torch.constant.none + %none_3118 = torch.constant.none + %cpu_3119 = torch.constant.device "cpu" + %false_3120 = torch.constant.bool false + %3159 = torch.aten.arange %int8_3116, %none_3117, %none_3118, %cpu_3119, %false_3120 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_3121 = torch.constant.int 1 + %int1_3122 = torch.constant.int 1 %int8_3123 = torch.constant.int 8 - %int128_3124 = torch.constant.int 128 - %2701 = torch.prim.ListConstruct %437, %int32_3120, %int2_3121, %int32_3122, %int8_3123, %int128_3124 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2702 = torch.aten.view %2538, %2701 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2702, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_3125 = torch.constant.int 32 - %2703 = torch.aten.mul.int %437, %int32_3125 : !torch.int, !torch.int -> !torch.int - %int2_3126 = torch.constant.int 2 - %2704 = torch.aten.mul.int %2703, %int2_3126 : !torch.int, !torch.int -> !torch.int - %int32_3127 = torch.constant.int 32 - %2705 = torch.aten.mul.int %2704, %int32_3127 : !torch.int, !torch.int -> !torch.int - %int8_3128 = torch.constant.int 8 - %int128_3129 = torch.constant.int 128 - %2706 = torch.prim.ListConstruct %2705, %int8_3128, %int128_3129 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2707 = torch.aten.view %2702, %2706 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2707, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %2708 = torch.prim.ListConstruct %2700 : (!torch.vtensor<[4,1],si64>) -> 
!torch.list<optional<vtensor>>
- %false_3130 = torch.constant.bool false
- %2709 = torch.aten.index_put %2707, %2708, %2681, %false_3130 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16>
- torch.bind_symbolic_shape %2709, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
- %int32_3131 = torch.constant.int 32
- %int2_3132 = torch.constant.int 2
- %int32_3133 = torch.constant.int 32
- %int8_3134 = torch.constant.int 8
- %int128_3135 = torch.constant.int 128
- %2710 = torch.prim.ListConstruct %437, %int32_3131, %int2_3132, %int32_3133, %int8_3134, %int128_3135 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2711 = torch.aten.view %2709, %2710 : !torch.vtensor<[?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %2711, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int2097152_3136 = torch.constant.int 2097152
- %2712 = torch.prim.ListConstruct %437, %int2097152_3136 : (!torch.int, !torch.int) -> !torch.list<int>
- %2713 = torch.aten.view %2711, %2712 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
- torch.bind_symbolic_shape %2713, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %3160 = torch.prim.ListConstruct %int1_3121, %int1_3122, %int8_3123 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3161 = torch.aten.view %3159, %3160 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[1,1,8],si64>
+ %none_3124 = torch.constant.none
+ %3162 = torch.aten.clone %174, %none_3124 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %3163 = torch.aten.detach %3162 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %3164 = torch.aten.detach %3163 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %3165 = torch.aten.detach %3164 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %int1_3125 = torch.constant.int 1
+ %int1_3126 = torch.constant.int 1
+ %int1_3127 = torch.constant.int 1
+ %3166 = torch.prim.ListConstruct %int1_3125, %int1_3126, %int1_3127 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3167 = torch.aten.view %3165, %3166 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64>
+ %int32_3128 = torch.constant.int 32
+ %3168 = torch.aten.mul.Scalar %3155, %int32_3128 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int12 = torch.constant.int 12
+ %int1_3129 = torch.constant.int 1
+ %3169 = torch.aten.add.Scalar %3168, %int12, %int1_3129 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int2_3130 = torch.constant.int 2
+ %3170 = torch.aten.mul.Scalar %3169, %int2_3130 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int1_3131 = torch.constant.int 1
+ %3171 = torch.aten.add.Tensor %3170, %3167, %int1_3131 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int8_3132 = torch.constant.int 8
+ %3172 = torch.aten.mul.Scalar %3171, %int8_3132 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int1_3133 = torch.constant.int 1
+ %3173 = torch.aten.add.Tensor %3172, %3161, %int1_3133 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int32_3134 = torch.constant.int 32
+ %3174 = torch.aten.mul.Scalar %3173, %int32_3134 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int1_3135 = torch.constant.int 1
+ %3175 = torch.aten.add.Tensor %3174, %3158, %int1_3135 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int5_3136 = torch.constant.int 5
+ %3176 = torch.prims.convert_element_type %3150, %int5_3136 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
 %int32_3137 = torch.constant.int 32
 %int2_3138 = torch.constant.int 2
- %int32_3139 = torch.constant.int 32
- %int8_3140 = torch.constant.int 8
+ %int8_3139 = torch.constant.int 8
+ %int32_3140 = torch.constant.int 32
 %int128_3141 = torch.constant.int 128
- %2714 = torch.prim.ListConstruct %437, %int32_3137, %int2_3138, %int32_3139, %int8_3140, %int128_3141 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2715 = torch.aten.view %2713, %2714 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %2715, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int8_3142 = torch.constant.int 8
- %int128_3143 = torch.constant.int 128
- %2716 = torch.prim.ListConstruct %2705, %int8_3142, %int128_3143 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2717 = torch.aten.view %2715, %2716 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,128],f16>
- torch.bind_symbolic_shape %2717, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
+ %3177 = torch.prim.ListConstruct %456, %int32_3137, %int2_3138, %int8_3139, %int32_3140, %int128_3141 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3178 = torch.aten.view %2998, %3177 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %3178, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int128_3142 = torch.constant.int 128
+ %3179 = torch.prim.ListConstruct %596, %int128_3142 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3180 = torch.aten.view %3178, %3179 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %3180, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %3181 = torch.prim.ListConstruct %3175 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+ %false_3143 = torch.constant.bool false
+ %3182 = torch.aten.index_put %3180, %3181, %3176, %false_3143 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %3182, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
 %int32_3144 = torch.constant.int 32
- %2718 = torch.aten.floor_divide.Scalar %arg2, %int32_3144 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
- %int1_3145 = torch.constant.int 1
- %2719 = torch.aten.unsqueeze %2718, %int1_3145 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int1_3146 = torch.constant.int 1
- %false_3147 = torch.constant.bool false
- %2720 = torch.aten.gather %arg3, %int1_3146, %2719, %false_3147 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64>
- %int32_3148 = torch.constant.int 32
- %2721 = torch.aten.remainder.Scalar %arg2, %int32_3148 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
- %int1_3149 = torch.constant.int 1
- %2722 = torch.aten.unsqueeze %2721, %int1_3149 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %none_3150 = torch.constant.none
- %2723 = torch.aten.clone %127, %none_3150 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
- %int0_3151 = torch.constant.int 0
- %2724 = torch.aten.unsqueeze %2723, %int0_3151 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
- %int4_3152 = torch.constant.int 4
- %int1_3153 = torch.constant.int 1
- %2725 = torch.prim.ListConstruct %int4_3152, %int1_3153 : (!torch.int, !torch.int) -> !torch.list<int>
- %int1_3154 = torch.constant.int 1
- %int1_3155 = torch.constant.int 1
- %2726 = torch.prim.ListConstruct %int1_3154, %int1_3155 : (!torch.int, !torch.int) -> !torch.list<int>
- %int4_3156 = torch.constant.int 4
- %int0_3157 = torch.constant.int 0
- %cpu_3158 = torch.constant.device "cpu"
- %false_3159 = torch.constant.bool false
- %2727 = torch.aten.empty_strided %2725, %2726, %int4_3156, %int0_3157, %cpu_3158, %false_3159 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64>
- %int11_3160 = torch.constant.int 11
- %2728 = torch.aten.fill.Scalar %2727, %int11_3160 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int4_3161 = torch.constant.int 4
+ %int2_3145 = torch.constant.int 2
+ %int8_3146 = torch.constant.int 8
+ %int32_3147 = torch.constant.int 32
+ %int128_3148 = torch.constant.int 128
+ %3183 = torch.prim.ListConstruct %456, %int32_3144, %int2_3145, %int8_3146, %int32_3147, %int128_3148 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3184 = torch.aten.view %3182, %3183 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %3184, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_3149 = torch.constant.int 2097152
+ %3185 = torch.prim.ListConstruct %456, %int2097152_3149 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3186 = torch.aten.view %3184, %3185 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %3186, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %int32_3150 = torch.constant.int 32
+ %int2_3151 = torch.constant.int 2
+ %int8_3152 = torch.constant.int 8
+ %int32_3153 = torch.constant.int 32
+ %int128_3154 = torch.constant.int 128
+ %3187 = torch.prim.ListConstruct %456, %int32_3150, %int2_3151, %int8_3152, %int32_3153, %int128_3154 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3188 = torch.aten.view %3186, %3187 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %3188, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int128_3155 = torch.constant.int 128
+ %3189 = torch.prim.ListConstruct %596, %int128_3155 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3190 = torch.aten.view %3188, %3189 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %3190, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %none_3156 = torch.constant.none
+ %3191 = torch.aten.clone %175, %none_3156 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %3192 = torch.aten.detach %3191 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %3193 = torch.aten.detach %3192 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %3194 = torch.aten.detach %3193 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %int1_3157 = torch.constant.int 1
+ %int1_3158 = torch.constant.int 1
+ %int1_3159 = torch.constant.int 1
+ %3195 = torch.prim.ListConstruct %int1_3157, %int1_3158, %int1_3159 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3196 = torch.aten.view %3194, %3195 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64>
+ %int32_3160 = torch.constant.int 32
+ %3197 = torch.aten.mul.Scalar %3155, %int32_3160 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int12_3161 = torch.constant.int 12
 %int1_3162 = torch.constant.int 1
- %2729 = torch.prim.ListConstruct %int4_3161, %int1_3162 : (!torch.int, !torch.int) -> !torch.list<int>
- %2730 = torch.aten.repeat %2724, %2729 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64>
- %int32_3163 = torch.constant.int 32
- %2731 = torch.aten.mul.Scalar %2720, %int32_3163 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
+ %3198 = torch.aten.add.Scalar %3197, %int12_3161, %int1_3162 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int2_3163 = torch.constant.int 2
+ %3199 = torch.aten.mul.Scalar %3198, %int2_3163 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
 %int1_3164 = torch.constant.int 1
- %2732 = torch.aten.add.Tensor %2731, %2728, %int1_3164 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int2_3165 = torch.constant.int 2
- %2733 = torch.aten.mul.Scalar %2732, %int2_3165 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
+ %3200 = torch.aten.add.Tensor %3199, %3196, %int1_3164 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int8_3165 = torch.constant.int 8
+ %3201 = torch.aten.mul.Scalar %3200, %int8_3165 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
 %int1_3166 = torch.constant.int 1
- %2734 = torch.aten.add.Tensor %2733, %2730, %int1_3166 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
+ %3202 = torch.aten.add.Tensor %3201, %3161, %int1_3166 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
 %int32_3167 = torch.constant.int 32
- %2735 = torch.aten.mul.Scalar %2734, %int32_3167 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
+ %3203 = torch.aten.mul.Scalar %3202, %int32_3167 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
 %int1_3168 = torch.constant.int 1
- %2736 = torch.aten.add.Tensor %2735, %2722, %int1_3168 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %2737 = torch.prim.ListConstruct %2736 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>>
- %false_3169 = torch.constant.bool false
- %2738 = torch.aten.index_put %2717, %2737, %2669, %false_3169 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16>
- torch.bind_symbolic_shape %2738, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
- %int32_3170 = torch.constant.int 32
- %int2_3171 = torch.constant.int 2
- %int32_3172 = torch.constant.int 32
+ %3204 = torch.aten.add.Tensor %3203, %3158, %int1_3168 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int5_3169 = torch.constant.int 5
+ %3205 = torch.prims.convert_element_type %3130, %int5_3169 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+ %3206 = torch.prim.ListConstruct %3204 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+ %false_3170 = torch.constant.bool false
+ %3207 = torch.aten.index_put %3190, %3206, %3205, %false_3170 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %3207, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %int32_3171 = torch.constant.int 32
+ %int2_3172 = torch.constant.int 2
 %int8_3173 = torch.constant.int 8
- %int128_3174 = torch.constant.int 128
- %2739 = torch.prim.ListConstruct %437, %int32_3170, %int2_3171, %int32_3172, %int8_3173, %int128_3174 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2740 = torch.aten.view %2738, %2739 : !torch.vtensor<[?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %2740, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int2097152_3175 = torch.constant.int 2097152
- %2741 = torch.prim.ListConstruct %437, %int2097152_3175 : (!torch.int, !torch.int) -> !torch.list<int>
- %2742 = torch.aten.view %2740, %2741 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
- torch.bind_symbolic_shape %2742, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
- %int4_3176 = torch.constant.int 4
- %2743 = torch.prim.ListConstruct %int4_3176, %358 : (!torch.int, !torch.int) -> !torch.list<int>
- %int1_3177 = torch.constant.int 1
- %2744 = torch.prim.ListConstruct %358, %int1_3177 : (!torch.int, !torch.int) -> !torch.list<int>
- %int4_3178 = torch.constant.int 4
- %int0_3179 = torch.constant.int 0
- %cpu_3180 = torch.constant.device "cpu"
- %false_3181 = torch.constant.bool false
- %2745 = torch.aten.empty_strided %2743, %2744, %int4_3178, %int0_3179, %cpu_3180, %false_3181 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %2745, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int11_3182 = torch.constant.int 11
- %2746 = torch.aten.fill.Scalar %2745, %int11_3182 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %2746, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+ %int32_3174 = torch.constant.int 32
+ %int128_3175 = torch.constant.int 128
+ %3208 = torch.prim.ListConstruct %456, %int32_3171, %int2_3172, %int8_3173, %int32_3174, %int128_3175 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3209 = torch.aten.view %3207, %3208 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %3209, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_3176 = torch.constant.int 2097152
+ %3210 = torch.prim.ListConstruct %456, %int2097152_3176 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3211 = torch.aten.view %3209, %3210 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %3211, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %none_3177 = torch.constant.none
+ %3212 = torch.aten.clone %176, %none_3177 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %3213 = torch.aten.detach %3212 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %3214 = torch.aten.detach %3213 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %3215 = torch.aten.detach %3214 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %none_3178 = torch.constant.none
+ %3216 = torch.aten.clone %177, %none_3178 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %3217 = torch.aten.detach %3216 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %3218 = torch.aten.detach %3217 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %3219 = torch.aten.detach %3218 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %none_3179 = torch.constant.none
+ %3220 = torch.aten.clone %178, %none_3179 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %3221 = torch.aten.detach %3220 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %3222 = torch.aten.detach %3221 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %3223 = torch.aten.detach %3222 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %int32_3180 = torch.constant.int 32
+ %int2_3181 = torch.constant.int 2
+ %int8_3182 = torch.constant.int 8
 %int32_3183 = torch.constant.int 32
- %2747 = torch.aten.mul.Scalar %arg3, %int32_3183 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %2747, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int1_3184 = torch.constant.int 1
- %2748 = torch.aten.add.Tensor %2747, %2746, %int1_3184 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %2748, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int4_3185 = torch.constant.int 4
- %2749 = torch.aten.mul.int %int4_3185, %358 : !torch.int, !torch.int -> !torch.int
- %2750 = torch.prim.ListConstruct %2749 : (!torch.int) -> !torch.list<int>
- %2751 = torch.aten.view %2748, %2750 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
- torch.bind_symbolic_shape %2751, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
- %int32_3186 = torch.constant.int 32
- %int2_3187 = torch.constant.int 2
- %int32_3188 = torch.constant.int 32
- %int8_3189 = torch.constant.int 8
- %int128_3190 = torch.constant.int 128
- %2752 = torch.prim.ListConstruct %437, %int32_3186, %int2_3187, %int32_3188, %int8_3189, %int128_3190 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2753 = torch.aten.view %2742, %2752 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %2753, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int32_3191 = torch.constant.int 32
- %2754 = torch.aten.mul.int %437, %int32_3191 : !torch.int, !torch.int -> !torch.int
- %int2_3192 = torch.constant.int 2
- %int32_3193 = torch.constant.int 32
- %int8_3194 = torch.constant.int 8
- %int128_3195 = torch.constant.int 128
- %2755 = torch.prim.ListConstruct %2754, %int2_3192, %int32_3193, %int8_3194, %int128_3195 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2756 = torch.aten.view %2753, %2755 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2,32,8,128],f16>
- torch.bind_symbolic_shape %2756, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16>
- %int0_3196 = torch.constant.int 0
- %2757 = torch.aten.index_select %2756, %int0_3196, %2751 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16>
- torch.bind_symbolic_shape %2757, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16>
- %int4_3197 = torch.constant.int 4
- %int2_3198 = torch.constant.int 2
- %int32_3199 = torch.constant.int 32
- %int8_3200 = torch.constant.int 8
- %int128_3201 = torch.constant.int 128
- %2758 = torch.prim.ListConstruct %int4_3197, %358, %int2_3198, %int32_3199, %int8_3200, %int128_3201 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2759 = torch.aten.view %2757, %2758 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %2759, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int0_3202 = torch.constant.int 0
- %int0_3203 = torch.constant.int 0
- %int9223372036854775807_3204 = torch.constant.int 9223372036854775807
- %int1_3205 = torch.constant.int 1
- %2760 = torch.aten.slice.Tensor %2759, %int0_3202, %int0_3203, %int9223372036854775807_3204, %int1_3205 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %2760, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int1_3206 = torch.constant.int 1
+ %int128_3184 = torch.constant.int 128
+ %3224 = torch.prim.ListConstruct %456, %int32_3180, %int2_3181, %int8_3182, %int32_3183, %int128_3184 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3225 = torch.aten.view %3211, %3224 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %3225, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %3226 = torch_c.to_builtin_tensor %3225 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+ %3227 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+ %cast_3185 = tensor.cast %3227 : tensor<4x?xi64> to tensor<?x?xi64>
+ %3228 = torch_c.to_builtin_tensor %3215 : !torch.vtensor<[],si64> -> tensor<i64>
+ %3229 = torch_c.to_builtin_tensor %3219 : !torch.vtensor<[],si64> -> tensor<i64>
+ %3230 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%3226, %cast_3185, %3228, %3229) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+ %cast_3186 = tensor.cast %3230 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+ %3231 = torch_c.from_builtin_tensor %cast_3186 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+ torch.bind_symbolic_shape %3231, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+ %3232 = torch_c.to_builtin_tensor %3225 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+ %3233 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+ %cast_3187 = tensor.cast %3233 : tensor<4x?xi64> to tensor<?x?xi64>
+ %3234 = torch_c.to_builtin_tensor %3215 : !torch.vtensor<[],si64> -> tensor<i64>
+ %3235 = torch_c.to_builtin_tensor %3223 : !torch.vtensor<[],si64> -> tensor<i64>
+ %3236 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%3232, %cast_3187, %3234, %3235) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+ %cast_3188 = tensor.cast %3236 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+ %3237 = torch_c.from_builtin_tensor %cast_3188 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+ torch.bind_symbolic_shape %3237, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+ %int2_3189 = torch.constant.int 2
+ %int3_3190 = torch.constant.int 3
+ %3238 = torch.aten.transpose.int %3231, %int2_3189, %int3_3190 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %3238, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int0_3191 = torch.constant.int 0
+ %3239 = torch.aten.clone %3238, %int0_3191 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %3239, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int4_3192 = torch.constant.int 4
+ %int8_3193 = torch.constant.int 8
+ %int128_3194 = torch.constant.int 128
+ %3240 = torch.prim.ListConstruct %int4_3192, %457, %int8_3193, %int128_3194 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3241 = torch.aten._unsafe_view %3239, %3240 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %3241, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int2_3195 = torch.constant.int 2
+ %int3_3196 = torch.constant.int 3
+ %3242 = torch.aten.transpose.int %3237, %int2_3195, %int3_3196 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %3242, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int0_3197 = torch.constant.int 0
+ %3243 = torch.aten.clone %3242, %int0_3197 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %3243, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int4_3198 = torch.constant.int 4
+ %int8_3199 = torch.constant.int 8
+ %int128_3200 = torch.constant.int 128
+ %3244 = torch.prim.ListConstruct %int4_3198, %457, %int8_3199, %int128_3200 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3245 = torch.aten._unsafe_view %3243, %3244 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %3245, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int-2_3201 = torch.constant.int -2
+ %3246 = torch.aten.unsqueeze %3241, %int-2_3201 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+ torch.bind_symbolic_shape %3246, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+ %int4_3202 = torch.constant.int 4
+ %int8_3203 = torch.constant.int 8
+ %int4_3204 = torch.constant.int 4
+ %int128_3205 = torch.constant.int 128
+ %3247 = torch.prim.ListConstruct %int4_3202, %457, %int8_3203, %int4_3204, %int128_3205 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %false_3206 = torch.constant.bool false
+ %3248 = torch.aten.expand %3246, %3247, %false_3206 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %3248, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
 %int0_3207 = torch.constant.int 0
- %int9223372036854775807_3208 = torch.constant.int 9223372036854775807
- %int1_3209 = torch.constant.int 1
- %2761 = torch.aten.slice.Tensor %2760, %int1_3206, %int0_3207, %int9223372036854775807_3208, %int1_3209 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %2761, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int2_3210 = torch.constant.int 2
- %int0_3211 = torch.constant.int 0
- %2762 = torch.aten.select.int %2761, %int2_3210, %int0_3211 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %2762, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int32_3212 = torch.constant.int 32
- %2763 = torch.aten.mul.int %358, %int32_3212 : !torch.int, !torch.int -> !torch.int
- %int2_3213 = torch.constant.int 2
- %int0_3214 = torch.constant.int 0
- %int1_3215 = torch.constant.int 1
- %2764 = torch.aten.slice.Tensor %2762, %int2_3213, %int0_3214, %2763, %int1_3215 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %2764, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int0_3216 = torch.constant.int 0
- %2765 = torch.aten.clone %2764, %int0_3216 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %2765, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int1_3217 = torch.constant.int 1
- %2766 = torch.aten.size.int %2761, %int1_3217 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int
- %int32_3218 = torch.constant.int 32
- %2767 = torch.aten.mul.int %2766, %int32_3218 : !torch.int, !torch.int -> !torch.int
- %int4_3219 = torch.constant.int 4
- %int8_3220 = torch.constant.int 8
- %int128_3221 = torch.constant.int 128
- %2768 = torch.prim.ListConstruct %int4_3219, %2767, %int8_3220, %int128_3221 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2769 = torch.aten._unsafe_view %2765, %2768 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %2769, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int0_3222 = torch.constant.int 0
- %int0_3223 = torch.constant.int 0
- %int9223372036854775807_3224 = torch.constant.int 9223372036854775807
+ %3249 = torch.aten.clone %3248, %int0_3207 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %3249, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4_3208 = torch.constant.int 4
+ %int32_3209 = torch.constant.int 32
+ %int128_3210 = torch.constant.int 128
+ %3250 = torch.prim.ListConstruct %int4_3208, %457, %int32_3209, %int128_3210 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3251 = torch.aten._unsafe_view %3249, %3250 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %3251, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int-2_3211 = torch.constant.int -2
+ %3252 = torch.aten.unsqueeze %3245, %int-2_3211 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+ torch.bind_symbolic_shape %3252, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+ %int4_3212 = torch.constant.int 4
+ %int8_3213 = torch.constant.int 8
+ %int4_3214 = torch.constant.int 4
+ %int128_3215 = torch.constant.int 128
+ %3253 = torch.prim.ListConstruct %int4_3212, %457, %int8_3213, %int4_3214, %int128_3215 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %false_3216 = torch.constant.bool false
+ %3254 = torch.aten.expand %3252, %3253, %false_3216 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %3254, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int0_3217 = torch.constant.int 0
+ %3255 = torch.aten.clone %3254, %int0_3217 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %3255, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4_3218 = torch.constant.int 4
+ %int32_3219 = torch.constant.int 32
+ %int128_3220 = torch.constant.int 128
+ %3256 = torch.prim.ListConstruct %int4_3218, %457, %int32_3219, %int128_3220 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3257 = torch.aten._unsafe_view %3255, %3256 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %3257, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int1_3221 = torch.constant.int 1
+ %int2_3222 = torch.constant.int 2
+ %3258 = torch.aten.transpose.int %3140, %int1_3221, %int2_3222 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+ %int1_3223 = torch.constant.int 1
+ %int2_3224 = torch.constant.int 2
+ %3259 = torch.aten.transpose.int %3251, %int1_3223, %int2_3224 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %3259, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
 %int1_3225 = torch.constant.int 1
- %2770 = torch.aten.slice.Tensor %2769, %int0_3222, %int0_3223, %int9223372036854775807_3224, %int1_3225 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %2770, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int0_3226 = torch.constant.int 0
- %int0_3227 = torch.constant.int 0
- %int9223372036854775807_3228 = torch.constant.int 9223372036854775807
- %int1_3229 = torch.constant.int 1
- %2771 = torch.aten.slice.Tensor %2759, %int0_3226, %int0_3227, %int9223372036854775807_3228, %int1_3229 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %2771, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
+ %int2_3226 = torch.constant.int 2
+ %3260 = torch.aten.transpose.int %3257, %int1_3225, %int2_3226 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %3260, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %float0.000000e00_3227 = torch.constant.float 0.000000e+00
+ %false_3228 = torch.constant.bool false
+ %none_3229 = torch.constant.none
+ %3261:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3258, %3259, %3260, %float0.000000e00_3227, %false_3228, %470, %none_3229) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
 %int1_3230 = torch.constant.int 1
- %int0_3231 = torch.constant.int 0
- %int9223372036854775807_3232 = torch.constant.int 9223372036854775807
+ %int2_3231 = torch.constant.int 2
+ %3262 = torch.aten.transpose.int %3261#0, %int1_3230, %int2_3231 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
+ %int4_3232 = torch.constant.int 4
 %int1_3233 = torch.constant.int 1
- %2772 = torch.aten.slice.Tensor %2771, %int1_3230, %int0_3231, %int9223372036854775807_3232, %int1_3233 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %2772, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int2_3234 = torch.constant.int 2
- %int1_3235 = torch.constant.int 1
- %2773 = torch.aten.select.int %2772, %int2_3234, %int1_3235 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %2773, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int2_3236 = torch.constant.int 2
- %int0_3237 = torch.constant.int 0
- %int1_3238 = torch.constant.int 1
- %2774 = torch.aten.slice.Tensor %2773, %int2_3236, %int0_3237, %2763, %int1_3238 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %2774, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int0_3239 = torch.constant.int 0
- %2775 = torch.aten.clone %2774, %int0_3239 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %2775, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int1_3240 = torch.constant.int 1
- %2776 = torch.aten.size.int %2772, %int1_3240 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int
- %int32_3241 = torch.constant.int 32
- %2777 = torch.aten.mul.int %2776, %int32_3241 : !torch.int, !torch.int -> !torch.int
- %int4_3242 = torch.constant.int 4
- %int8_3243 = torch.constant.int 8
- %int128_3244 = torch.constant.int 128
- %2778 = torch.prim.ListConstruct %int4_3242, %2777, %int8_3243, %int128_3244 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2779 = torch.aten._unsafe_view %2775, %2778 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %2779, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int0_3245 = torch.constant.int 0
- %int0_3246 = torch.constant.int 0
- %int9223372036854775807_3247 = torch.constant.int 9223372036854775807
- %int1_3248 = torch.constant.int 1
- %2780 = torch.aten.slice.Tensor %2779, %int0_3245, %int0_3246, %int9223372036854775807_3247, %int1_3248 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %2780, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int-2_3249 = torch.constant.int -2
- %2781 = torch.aten.unsqueeze %2770, %int-2_3249 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
- torch.bind_symbolic_shape %2781, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+ %int4096_3234 = torch.constant.int 4096
+ %3263 = torch.prim.ListConstruct %int4_3232, %int1_3233, %int4096_3234 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3264 = torch.aten.view %3262, %3263 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+ %int-2_3235 = torch.constant.int -2
+ %int-1_3236 = torch.constant.int -1
+ %3265 = torch.aten.transpose.int %179, %int-2_3235, %int-1_3236 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int5_3237 = torch.constant.int 5
+ %3266 = torch.prims.convert_element_type %3265, %int5_3237 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int4_3238 = torch.constant.int 4
+ %int4096_3239 = torch.constant.int 4096
+ %3267 = torch.prim.ListConstruct %int4_3238, %int4096_3239 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3268 = torch.aten.view %3264, %3267 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+ %3269 = torch.aten.mm %3268, %3266 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
+ %int4_3240 = torch.constant.int 4
+ %int1_3241 = torch.constant.int 1
+ %int4096_3242 = torch.constant.int 4096
+ %3270 = torch.prim.ListConstruct %int4_3240, %int1_3241, %int4096_3242 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3271 = torch.aten.view %3269, %3270 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+ %int1_3243 = torch.constant.int 1
+ %3272 = torch.aten.add.Tensor %3093, %3271, %int1_3243 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+ %int6_3244 = torch.constant.int 6
+ %3273 = torch.prims.convert_element_type %3272, %int6_3244 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+ %int2_3245 = torch.constant.int 2
+ %3274 = torch.aten.pow.Tensor_Scalar %3273, %int2_3245 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+ %int-1_3246 = torch.constant.int -1
+ %3275 = torch.prim.ListConstruct %int-1_3246 : (!torch.int) -> !torch.list<int>
+ %true_3247 = torch.constant.bool true
+ %none_3248 = torch.constant.none
+ %3276 = torch.aten.mean.dim %3274, %3275, %true_3247, %none_3248 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
+ %float9.999990e-06_3249 = torch.constant.float 9.9999997473787516E-6
 %int1_3250 = torch.constant.int 1
- %2782 = torch.aten.size.int %2769, %int1_3250 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int
- %int4_3251 = torch.constant.int 4
- %int8_3252 = torch.constant.int 8
- %int4_3253 = torch.constant.int 4
- %int128_3254 = torch.constant.int 128
- %2783 = torch.prim.ListConstruct %int4_3251, %2782, %int8_3252, %int4_3253, %int128_3254 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %false_3255 = torch.constant.bool false
- %2784 = torch.aten.expand %2781, %2783, %false_3255 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %2784, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int0_3256 = torch.constant.int 0
- %2785 = torch.aten.clone %2784, %int0_3256 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %2785, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int4_3257 = torch.constant.int 4
- %int32_3258 = torch.constant.int 32
- %int128_3259 = torch.constant.int 128
- %2786 = torch.prim.ListConstruct %int4_3257, %2782, %int32_3258, %int128_3259 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2787 = torch.aten._unsafe_view %2785, %2786 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %2787, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int-2_3260 = torch.constant.int -2
- %2788 = torch.aten.unsqueeze %2780, %int-2_3260 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
- torch.bind_symbolic_shape %2788, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
- %int1_3261 = torch.constant.int 1
- %2789 = torch.aten.size.int %2779, %int1_3261 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int
- %int4_3262 = torch.constant.int 4
- %int8_3263 = torch.constant.int 8
+ %3277 = torch.aten.add.Scalar %3276, %float9.999990e-06_3249, %int1_3250 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
+ %3278 = torch.aten.rsqrt %3277 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
+ %3279 = torch.aten.mul.Tensor %3273, %3278 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
+ %int5_3251 = torch.constant.int 5
+ %3280 = torch.prims.convert_element_type %3279, %int5_3251 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+ %3281 = torch.aten.mul.Tensor %180, %3280 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
+ %int5_3252 = torch.constant.int 5
+ %3282 = torch.prims.convert_element_type %3281, %int5_3252 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+ %int-2_3253 = torch.constant.int -2
+ %int-1_3254 = torch.constant.int -1
+ %3283 = torch.aten.transpose.int %181, %int-2_3253, %int-1_3254 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int5_3255 = torch.constant.int 5
+ %3284 = torch.prims.convert_element_type %3283, %int5_3255 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int4_3256 = torch.constant.int 4
+ %int4096_3257 = torch.constant.int 4096
+ %3285 = torch.prim.ListConstruct %int4_3256, %int4096_3257 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3286 = torch.aten.view %3282, %3285 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+ %3287 = torch.aten.mm %3286, %3284 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
+ %int4_3258 = torch.constant.int 4
+ %int1_3259 = torch.constant.int 1
+ %int14336_3260 = torch.constant.int 14336
+ %3288 = torch.prim.ListConstruct %int4_3258, %int1_3259, %int14336_3260 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3289 = torch.aten.view %3287, %3288 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
+ %3290 = torch.aten.silu %3289 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
+ %int-2_3261 = torch.constant.int -2
+ %int-1_3262 = torch.constant.int -1
+ %3291 = torch.aten.transpose.int %182, %int-2_3261, %int-1_3262 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int5_3263 = torch.constant.int 5
+ %3292 = torch.prims.convert_element_type %3291, %int5_3263 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
 %int4_3264 = torch.constant.int 4
- %int128_3265 = torch.constant.int 128
- %2790 = torch.prim.ListConstruct %int4_3262, %2789, %int8_3263, %int4_3264, %int128_3265 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %false_3266 = torch.constant.bool false
- %2791 = torch.aten.expand %2788, %2790, %false_3266 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %2791, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int0_3267 = torch.constant.int 0
- %2792 = torch.aten.clone %2791, %int0_3267 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %2792, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int4_3268 = torch.constant.int 4
- %int32_3269 = torch.constant.int 32
- %int128_3270 = torch.constant.int 128
- %2793 = torch.prim.ListConstruct %int4_3268, %2789, %int32_3269, %int128_3270 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2794 = torch.aten._unsafe_view %2792, %2793 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %2794, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int1_3271 = torch.constant.int 1
- %int2_3272 = torch.constant.int 2
- %2795 = torch.aten.transpose.int %2675, %int1_3271, %int2_3272 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
- %int1_3273 = torch.constant.int 1
- %int2_3274 = torch.constant.int 2
- %2796 = torch.aten.transpose.int %2787, %int1_3273, %int2_3274 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %2796, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int4096_3265 = torch.constant.int 4096
+ %3293 = torch.prim.ListConstruct %int4_3264, %int4096_3265 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3294 = torch.aten.view %3282, %3293 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+ %3295 = torch.aten.mm %3294, %3292 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
+ %int4_3266 = torch.constant.int 4
+ %int1_3267 = torch.constant.int 1
+ %int14336_3268 = torch.constant.int 14336
+ %3296 = torch.prim.ListConstruct %int4_3266, %int1_3267, %int14336_3268 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3297 = torch.aten.view %3295, %3296 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
+ %3298 = torch.aten.mul.Tensor %3290, %3297 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
+ %int-2_3269 = torch.constant.int -2
+ %int-1_3270 = torch.constant.int -1
+ %3299 = torch.aten.transpose.int %183, %int-2_3269, %int-1_3270 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
+ %int5_3271 = torch.constant.int 5
+ %3300 = torch.prims.convert_element_type %3299, %int5_3271 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16>
+ %int4_3272 = torch.constant.int 4
+ %int14336_3273 = torch.constant.int 14336
+ %3301 = torch.prim.ListConstruct %int4_3272, %int14336_3273 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3302 = torch.aten.view %3298, %3301 : !torch.vtensor<[4,1,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,14336],f16>
+ %3303 = torch.aten.mm %3302, %3300 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16>
+ %int4_3274 = torch.constant.int 4
 %int1_3275 = torch.constant.int 1
- %int2_3276 = torch.constant.int 2
- %2797 = torch.aten.transpose.int %2794, %int1_3275, %int2_3276 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %2797, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %float0.000000e00_3277 = torch.constant.float 0.000000e+00
- %false_3278 = torch.constant.bool false
- %none_3279 = torch.constant.none
- %2798:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2795, %2796, %2797, %float0.000000e00_3277, %false_3278, %368, %none_3279) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
- %int1_3280 = torch.constant.int 1
- %int2_3281 = torch.constant.int 2
- %2799 = torch.aten.transpose.int %2798#0, %int1_3280, %int2_3281 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
- %int4_3282 = torch.constant.int 4
- %int1_3283 = torch.constant.int 1
- %int4096_3284 = torch.constant.int 4096
- %2800 = torch.prim.ListConstruct %int4_3282, %int1_3283, %int4096_3284 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2801 = torch.aten.view %2799, %2800 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
- %int-2_3285 = torch.constant.int -2
- %int-1_3286 = torch.constant.int -1
- %2802 = torch.aten.transpose.int %128, %int-2_3285, %int-1_3286 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_3287 = torch.constant.int 4
- %int4096_3288 = torch.constant.int 4096
- %2803 = torch.prim.ListConstruct %int4_3287, %int4096_3288 : (!torch.int, !torch.int) -> !torch.list<int>
- %2804 = torch.aten.view %2801, %2803 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %2805 = torch.aten.mm %2804, %2802 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
- %int4_3289 = torch.constant.int 4
- %int1_3290 = torch.constant.int 1
+ %int4096_3276 = torch.constant.int 4096
+ %3304 = torch.prim.ListConstruct %int4_3274, %int1_3275, %int4096_3276 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3305 = torch.aten.view %3303, %3304 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+ %int1_3277 = torch.constant.int 1
+ %3306 = torch.aten.add.Tensor %3272, %3305, %int1_3277 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+ %int6_3278 = torch.constant.int 6
+ %3307 = torch.prims.convert_element_type %3306, %int6_3278 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+ %int2_3279 = torch.constant.int 2
+ %3308 = torch.aten.pow.Tensor_Scalar %3307, %int2_3279 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+ %int-1_3280 = torch.constant.int -1
+ %3309 = torch.prim.ListConstruct %int-1_3280 : (!torch.int) -> !torch.list<int>
+ %true_3281 = torch.constant.bool true
+ %none_3282 = torch.constant.none
+ %3310 = torch.aten.mean.dim %3308, %3309, %true_3281, %none_3282 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
+ %float9.999990e-06_3283 = torch.constant.float 9.9999997473787516E-6
+ %int1_3284 = torch.constant.int 1
+ %3311 = torch.aten.add.Scalar %3310, %float9.999990e-06_3283, %int1_3284 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
+ %3312 = torch.aten.rsqrt %3311 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
+ %3313 = torch.aten.mul.Tensor %3307, %3312 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
+ %int5_3285 = torch.constant.int 5
+ %3314 = torch.prims.convert_element_type %3313, %int5_3285 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+ %3315 = torch.aten.mul.Tensor %184, %3314 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
+ %int5_3286 = torch.constant.int 5
+ %3316 = torch.prims.convert_element_type %3315, %int5_3286 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+ %int-2_3287 = torch.constant.int -2
+ %int-1_3288 = torch.constant.int -1
+ %3317 = torch.aten.transpose.int %185, %int-2_3287, %int-1_3288 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int5_3289 = torch.constant.int 5
+ %3318 = torch.prims.convert_element_type %3317, %int5_3289 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int4_3290 = torch.constant.int 4
 %int4096_3291 = torch.constant.int 4096
- %2806 = torch.prim.ListConstruct %int4_3289, %int1_3290, %int4096_3291 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2807 = torch.aten.view %2805, %2806 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
- %int1_3292 = torch.constant.int 1
- %2808 = torch.aten.add.Tensor %2635, %2807, %int1_3292 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %int6_3293 = torch.constant.int 6
- %2809 = torch.prims.convert_element_type %2808, %int6_3293 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
- %int2_3294 = torch.constant.int 2
- %2810 = torch.aten.pow.Tensor_Scalar %2809, %int2_3294 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
- %int-1_3295 = torch.constant.int -1
- %2811 = torch.prim.ListConstruct %int-1_3295 : (!torch.int) -> !torch.list<int>
- %true_3296 = torch.constant.bool true
- %none_3297 = torch.constant.none
- %2812 = torch.aten.mean.dim %2810, %2811, %true_3296, %none_3297 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
- %float9.999990e-06_3298 = torch.constant.float 9.9999997473787516E-6
- %int1_3299 = torch.constant.int 1
- %2813 = torch.aten.add.Scalar %2812, %float9.999990e-06_3298, %int1_3299 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
- %2814 = torch.aten.rsqrt %2813 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
- %2815 = torch.aten.mul.Tensor %2809, %2814 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
- %int5_3300 = torch.constant.int 5
- %2816 = torch.prims.convert_element_type %2815, %int5_3300 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %2817 = torch.aten.mul.Tensor %129, %2816 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
- %int5_3301 = torch.constant.int 5
- %2818 = torch.prims.convert_element_type %2817, %int5_3301 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %int-2_3302 = torch.constant.int -2
- %int-1_3303 = torch.constant.int -1
- %2819 = torch.aten.transpose.int %130, %int-2_3302, %int-1_3303 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_3304 = torch.constant.int 4
- %int4096_3305 = torch.constant.int 4096
- %2820 = torch.prim.ListConstruct %int4_3304, %int4096_3305 : (!torch.int, !torch.int) -> !torch.list<int>
- %2821 = torch.aten.view %2818, %2820 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %2822 = torch.aten.mm %2821, %2819 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
+ %3319 = torch.prim.ListConstruct %int4_3290, %int4096_3291 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3320 = torch.aten.view %3316, %3319 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+ %3321 = torch.aten.mm %3320, %3318 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
+ %int4_3292 = torch.constant.int 4
+ %int1_3293 = torch.constant.int 1
+ %int4096_3294 = torch.constant.int 4096
+ %3322 = torch.prim.ListConstruct %int4_3292, %int1_3293, %int4096_3294 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3323 = torch.aten.view %3321, %3322 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+ %int-2_3295 = torch.constant.int -2
+ %int-1_3296 = torch.constant.int -1
+ %3324 = torch.aten.transpose.int %186, %int-2_3295, %int-1_3296 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int5_3297 = torch.constant.int 5
+ %3325 = torch.prims.convert_element_type %3324, %int5_3297 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int4_3298 = torch.constant.int 4
+ %int4096_3299 = torch.constant.int 4096
+ %3326 = torch.prim.ListConstruct %int4_3298, %int4096_3299 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3327 = torch.aten.view %3316, %3326 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+ %3328 = torch.aten.mm %3327, %3325 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
+ %int4_3300 = torch.constant.int 4
+ %int1_3301 = torch.constant.int 1
+ %int1024_3302 = torch.constant.int 1024
+ %3329 = torch.prim.ListConstruct %int4_3300, %int1_3301, %int1024_3302 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3330 = torch.aten.view %3328, %3329 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
+ %int-2_3303 = torch.constant.int -2
+ %int-1_3304 = torch.constant.int -1
+ %3331 = torch.aten.transpose.int %187, %int-2_3303, %int-1_3304 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int5_3305 = torch.constant.int 5
+ %3332 = torch.prims.convert_element_type %3331, %int5_3305 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
 %int4_3306 = torch.constant.int 4
- %int1_3307 = torch.constant.int 1
- %int14336_3308 = torch.constant.int 14336
- %2823 = torch.prim.ListConstruct %int4_3306, %int1_3307, %int14336_3308 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2824 = torch.aten.view %2822, %2823 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
- %2825 = torch.aten.silu %2824 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
- %int-2_3309 = torch.constant.int -2
- %int-1_3310 = torch.constant.int -1
- %2826 = torch.aten.transpose.int %131, %int-2_3309, %int-1_3310 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+ %int4096_3307 = torch.constant.int 4096
+ %3333 = torch.prim.ListConstruct %int4_3306, %int4096_3307 : (!torch.int, !torch.int) -> !torch.list<int>
+ %3334 = torch.aten.view %3316, %3333 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+ %3335 = torch.aten.mm %3334, %3332 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
+ %int4_3308 = torch.constant.int 4
+ %int1_3309 = torch.constant.int 1
+ %int1024_3310 = torch.constant.int 1024
+ %3336 = torch.prim.ListConstruct %int4_3308, %int1_3309, %int1024_3310 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3337 = torch.aten.view %3335, %3336 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
 %int4_3311 = torch.constant.int 4
- %int4096_3312 = torch.constant.int 4096
- %2827 = torch.prim.ListConstruct %int4_3311, %int4096_3312 : (!torch.int, !torch.int) -> !torch.list<int>
- %2828 = torch.aten.view %2818, %2827 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %2829 = torch.aten.mm %2828, %2826 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
- %int4_3313 = torch.constant.int 4
- %int1_3314 = torch.constant.int 1
- %int14336_3315 = torch.constant.int 14336
- %2830 = torch.prim.ListConstruct %int4_3313, %int1_3314, %int14336_3315 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2831 = torch.aten.view %2829, %2830 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
- %2832 = torch.aten.mul.Tensor %2825, %2831 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
- %int-2_3316 = torch.constant.int -2
- %int-1_3317 = torch.constant.int -1
- %2833 = torch.aten.transpose.int %132, %int-2_3316, %int-1_3317 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
- %int4_3318 = torch.constant.int 4
- %int14336_3319 = torch.constant.int 14336
- %2834 = torch.prim.ListConstruct %int4_3318, %int14336_3319 : (!torch.int, !torch.int) -> !torch.list<int>
- %2835 = torch.aten.view %2832, %2834 : !torch.vtensor<[4,1,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,14336],f16>
- %2836 = torch.aten.mm %2835, %2833 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16>
- %int4_3320 = torch.constant.int 4
- %int1_3321 = torch.constant.int 1
- %int4096_3322 = torch.constant.int 4096
- %2837 = torch.prim.ListConstruct %int4_3320, %int1_3321, %int4096_3322 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2838 = torch.aten.view %2836, %2837 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+ %int1_3312 = torch.constant.int 1
+ %int32_3313 = torch.constant.int 32
+ %int128_3314 = torch.constant.int 128
+ %3338 = torch.prim.ListConstruct %int4_3311, %int1_3312, %int32_3313, %int128_3314 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3339 = torch.aten.view %3323, %3338 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,32,128],f16>
+ %int4_3315 = torch.constant.int 4
+ %int1_3316 = torch.constant.int 1
+ %int8_3317 = torch.constant.int 8
+ %int128_3318 = torch.constant.int 128
+ %3340 = torch.prim.ListConstruct %int4_3315, %int1_3316, %int8_3317, %int128_3318 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3341 = torch.aten.view %3330, %3340 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
+ %int4_3319 = torch.constant.int 4
+ %int1_3320 = torch.constant.int 1
+ %int8_3321 = torch.constant.int 8
+ %int128_3322 = torch.constant.int 128
+ %3342 = torch.prim.ListConstruct %int4_3319, %int1_3320, %int8_3321, %int128_3322 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %3343 = torch.aten.view %3337, %3342 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
 %int1_3323 = torch.constant.int 1
- %2839 = torch.aten.add.Tensor %2808, %2838, %int1_3323 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %int6_3324 = torch.constant.int 6
- %2840 = torch.prims.convert_element_type %2839, %int6_3324 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
- %int2_3325 = torch.constant.int 2
- %2841 = torch.aten.pow.Tensor_Scalar %2840, %int2_3325 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
- %int-1_3326 = torch.constant.int -1
- %2842 = torch.prim.ListConstruct %int-1_3326 : (!torch.int) -> !torch.list<int>
- %true_3327 = torch.constant.bool true
- %none_3328 = torch.constant.none
- %2843 = torch.aten.mean.dim %2841, %2842, %true_3327, %none_3328 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
- %float9.999990e-06_3329 = torch.constant.float 9.9999997473787516E-6
- %int1_3330 = torch.constant.int 1
- %2844 = torch.aten.add.Scalar %2843, %float9.999990e-06_3329, %int1_3330 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
- %2845 = torch.aten.rsqrt %2844 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
- %2846 = torch.aten.mul.Tensor %2840, %2845 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
- %int5_3331 = torch.constant.int 5
- %2847 = torch.prims.convert_element_type %2846, %int5_3331 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %2848 = torch.aten.mul.Tensor %133, %2847 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
- %int5_3332 = torch.constant.int 5
- %2849 = torch.prims.convert_element_type %2848, %int5_3332 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %int-2_3333 = torch.constant.int -2
- %int-1_3334 = torch.constant.int -1
- %2850 = torch.aten.transpose.int %134, %int-2_3333, %int-1_3334 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_3335 = torch.constant.int 4
- %int4096_3336 = torch.constant.int 4096
- %2851 = torch.prim.ListConstruct %int4_3335, %int4096_3336 : (!torch.int, !torch.int) -> !torch.list<int>
- %2852 = torch.aten.view %2849, %2851 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %2853 = torch.aten.mm %2852, %2850 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
- %int4_3337 = torch.constant.int 4
- %int1_3338 = torch.constant.int 1
- %int4096_3339 = torch.constant.int 4096
- %2854 = torch.prim.ListConstruct %int4_3337, %int1_3338, %int4096_3339 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2855 = torch.aten.view %2853, %2854 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
- %int-2_3340 = torch.constant.int -2
- %int-1_3341 = torch.constant.int -1
- %2856 = torch.aten.transpose.int %135, %int-2_3340, %int-1_3341 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_3342 = torch.constant.int 4
- %int4096_3343 = torch.constant.int 4096
- %2857 = torch.prim.ListConstruct %int4_3342, %int4096_3343 : (!torch.int, !torch.int) -> !torch.list<int>
- %2858 = torch.aten.view %2849, %2857 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %2859 = torch.aten.mm %2858, %2856 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
- %int4_3344 = torch.constant.int 4
- %int1_3345 = torch.constant.int 1
- %int1024_3346 = torch.constant.int 1024
- %2860 = torch.prim.ListConstruct %int4_3344, %int1_3345, %int1024_3346 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %2861 = torch.aten.view %2859, %2860 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
- %int-2_3347 = torch.constant.int -2
- %int-1_3348 = torch.constant.int -1
- %2862 = torch.aten.transpose.int %136, %int-2_3347, %int-1_3348 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_3349 = torch.constant.int 4
- %int4096_3350 = torch.constant.int 4096
- %2863 = torch.prim.ListConstruct %int4_3349, %int4096_3350 : (!torch.int, !torch.int) -> !torch.list<int>
- %2864 = torch.aten.view %2849, %2863 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %2865 = torch.aten.mm %2864, %2862 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
- %int4_3351 = torch.constant.int 4
+ %int2_3324 = torch.constant.int 2
+ %3344 = torch.aten.transpose.int %3339, %int1_3323, %int2_3324 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+ %3345 = torch.aten.mul.Tensor %3344, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16>
+ %int3_3325 = torch.constant.int 3
+ %int0_3326 = torch.constant.int 0
+ %int64_3327 = torch.constant.int 64
+ %int1_3328 = torch.constant.int 1
+ %3346 = torch.aten.slice.Tensor %3344, %int3_3325, %int0_3326, %int64_3327, %int1_3328 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16>
+ %int3_3329 = torch.constant.int 3
+ %int64_3330 = torch.constant.int 64
+ %int9223372036854775807_3331 = torch.constant.int 9223372036854775807
+
%int1_3332 = torch.constant.int 1 + %3347 = torch.aten.slice.Tensor %3344, %int3_3329, %int64_3330, %int9223372036854775807_3331, %int1_3332 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %3348 = torch.aten.neg %3347 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %3349 = torch.prim.ListConstruct %3348, %3346 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_3333 = torch.constant.int -1 + %3350 = torch.aten.cat %3349, %int-1_3333 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %3351 = torch.aten.mul.Tensor %3350, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_3334 = torch.constant.int 1 + %3352 = torch.aten.add.Tensor %3345, %3351, %int1_3334 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_3335 = torch.constant.int 1 + %int2_3336 = torch.constant.int 2 + %3353 = torch.aten.transpose.int %3352, %int1_3335, %int2_3336 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_3337 = torch.constant.int 1 + %int2_3338 = torch.constant.int 2 + %3354 = torch.aten.transpose.int %3341, %int1_3337, %int2_3338 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %3355 = torch.aten.mul.Tensor %3354, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_3339 = torch.constant.int 3 + %int0_3340 = torch.constant.int 0 + %int64_3341 = torch.constant.int 64 + %int1_3342 = torch.constant.int 1 + %3356 = torch.aten.slice.Tensor %3354, %int3_3339, %int0_3340, %int64_3341, %int1_3342 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_3343 = torch.constant.int 3 + %int64_3344 = torch.constant.int 64 + %int9223372036854775807_3345 = torch.constant.int 9223372036854775807 + %int1_3346 = torch.constant.int 1 + %3357 = torch.aten.slice.Tensor %3354, %int3_3343, %int64_3344, %int9223372036854775807_3345, %int1_3346 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %3358 = torch.aten.neg %3357 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %3359 = torch.prim.ListConstruct %3358, %3356 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_3347 = torch.constant.int -1 + %3360 = torch.aten.cat %3359, %int-1_3347 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %3361 = torch.aten.mul.Tensor %3360, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_3348 = torch.constant.int 1 + %3362 = torch.aten.add.Tensor %3355, %3361, %int1_3348 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_3349 = torch.constant.int 1 + %int2_3350 = torch.constant.int 2 + %3363 = torch.aten.transpose.int %3362, %int1_3349, %int2_3350 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_3351 = torch.constant.int 32 + %3364 = torch.aten.floor_divide.Scalar %arg2, %int32_3351 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> %int1_3352 = torch.constant.int 1 - %int1024_3353 = torch.constant.int 1024 - 
%2866 = torch.prim.ListConstruct %int4_3351, %int1_3352, %int1024_3353 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2867 = torch.aten.view %2865, %2866 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_3354 = torch.constant.int 4 - %int1_3355 = torch.constant.int 1 - %int32_3356 = torch.constant.int 32 - %int128_3357 = torch.constant.int 128 - %2868 = torch.prim.ListConstruct %int4_3354, %int1_3355, %int32_3356, %int128_3357 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2869 = torch.aten.view %2855, %2868 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_3358 = torch.constant.int 4 - %int1_3359 = torch.constant.int 1 - %int8_3360 = torch.constant.int 8 - %int128_3361 = torch.constant.int 128 - %2870 = torch.prim.ListConstruct %int4_3358, %int1_3359, %int8_3360, %int128_3361 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2871 = torch.aten.view %2861, %2870 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_3362 = torch.constant.int 4 - %int1_3363 = torch.constant.int 1 - %int8_3364 = torch.constant.int 8 - %int128_3365 = torch.constant.int 128 - %2872 = torch.prim.ListConstruct %int4_3362, %int1_3363, %int8_3364, %int128_3365 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2873 = torch.aten.view %2867, %2872 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_3366 = torch.constant.int 6 - %2874 = torch.prims.convert_element_type %2869, %int6_3366 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %2875 = torch_c.to_builtin_tensor %2874 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %2876 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %2877 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%2875, %2876) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %2878 = torch_c.from_builtin_tensor %2877 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_3367 = torch.constant.int 5 - %2879 = torch.prims.convert_element_type %2878, %int5_3367 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_3368 = torch.constant.int 6 - %2880 = torch.prims.convert_element_type %2871, %int6_3368 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %2881 = torch_c.to_builtin_tensor %2880 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %2882 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %2883 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%2881, %2882) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %2884 = torch_c.from_builtin_tensor %2883 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_3369 = torch.constant.int 5 - %2885 = torch.prims.convert_element_type %2884, %int5_3369 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_3370 = torch.constant.int 32 - %2886 = torch.aten.floor_divide.Scalar %arg2, %int32_3370 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %3365 = torch.aten.unsqueeze %3364, %int1_3352 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_3353 = torch.constant.int 1 + %false_3354 = torch.constant.bool false + %3366 = torch.aten.gather %arg3, %int1_3353, %3365, %false_3354 : 
!torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_3355 = torch.constant.int 4 + %int1_3356 = torch.constant.int 1 + %int1_3357 = torch.constant.int 1 + %3367 = torch.prim.ListConstruct %int4_3355, %int1_3356, %int1_3357 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3368 = torch.aten.view %3366, %3367 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_3358 = torch.constant.int 32 + %3369 = torch.aten.remainder.Scalar %arg2, %int32_3358 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_3359 = torch.constant.int 4 + %int1_3360 = torch.constant.int 1 + %int1_3361 = torch.constant.int 1 + %3370 = torch.prim.ListConstruct %int4_3359, %int1_3360, %int1_3361 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3371 = torch.aten.view %3369, %3370 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_3362 = torch.constant.int 8 + %none_3363 = torch.constant.none + %none_3364 = torch.constant.none + %cpu_3365 = torch.constant.device "cpu" + %false_3366 = torch.constant.bool false + %3372 = torch.aten.arange %int8_3362, %none_3363, %none_3364, %cpu_3365, %false_3366 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_3367 = torch.constant.int 1 + %int1_3368 = torch.constant.int 1 + %int8_3369 = torch.constant.int 8 + %3373 = torch.prim.ListConstruct %int1_3367, %int1_3368, %int8_3369 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3374 = torch.aten.view %3372, %3373 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_3370 = torch.constant.none + %3375 = torch.aten.clone %188, %none_3370 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3376 = torch.aten.detach %3375 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3377 = torch.aten.detach %3376 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3378 = torch.aten.detach %3377 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %int1_3371 = torch.constant.int 1 - %2887 = torch.aten.unsqueeze %2886, %int1_3371 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> %int1_3372 = torch.constant.int 1 - %false_3373 = torch.constant.bool false - %2888 = torch.aten.gather %arg3, %int1_3372, %2887, %false_3373 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int1_3373 = torch.constant.int 1 + %3379 = torch.prim.ListConstruct %int1_3371, %int1_3372, %int1_3373 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3380 = torch.aten.view %3378, %3379 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> %int32_3374 = torch.constant.int 32 - %2889 = torch.aten.remainder.Scalar %arg2, %int32_3374 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %3381 = torch.aten.mul.Scalar %3368, %int32_3374 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int13 = torch.constant.int 13 %int1_3375 = torch.constant.int 1 - %2890 = torch.aten.unsqueeze %2889, %int1_3375 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_3376 = torch.constant.none - %2891 = torch.aten.clone %137, %none_3376 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_3377 = torch.constant.int 0 - %2892 = torch.aten.unsqueeze %2891, %int0_3377 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_3378 = 
torch.constant.int 4 + %3382 = torch.aten.add.Scalar %3381, %int13, %int1_3375 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_3376 = torch.constant.int 2 + %3383 = torch.aten.mul.Scalar %3382, %int2_3376 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_3377 = torch.constant.int 1 + %3384 = torch.aten.add.Tensor %3383, %3380, %int1_3377 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_3378 = torch.constant.int 8 + %3385 = torch.aten.mul.Scalar %3384, %int8_3378 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_3379 = torch.constant.int 1 - %2893 = torch.prim.ListConstruct %int4_3378, %int1_3379 : (!torch.int, !torch.int) -> !torch.list - %int1_3380 = torch.constant.int 1 + %3386 = torch.aten.add.Tensor %3385, %3374, %int1_3379 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_3380 = torch.constant.int 32 + %3387 = torch.aten.mul.Scalar %3386, %int32_3380 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_3381 = torch.constant.int 1 - %2894 = torch.prim.ListConstruct %int1_3380, %int1_3381 : (!torch.int, !torch.int) -> !torch.list - %int4_3382 = torch.constant.int 4 - %int0_3383 = torch.constant.int 0 - %cpu_3384 = torch.constant.device "cpu" - %false_3385 = torch.constant.bool false - %2895 = torch.aten.empty_strided %2893, %2894, %int4_3382, %int0_3383, %cpu_3384, %false_3385 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int12 = torch.constant.int 12 - %2896 = torch.aten.fill.Scalar %2895, %int12 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_3386 = torch.constant.int 4 - %int1_3387 = torch.constant.int 1 - %2897 = torch.prim.ListConstruct %int4_3386, %int1_3387 : (!torch.int, !torch.int) -> !torch.list - %2898 = torch.aten.repeat %2892, %2897 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_3388 = torch.constant.int 32 - %2899 = torch.aten.mul.Scalar %2888, %int32_3388 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3389 = torch.constant.int 1 - %2900 = torch.aten.add.Tensor %2899, %2896, %int1_3389 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_3390 = torch.constant.int 2 - %2901 = torch.aten.mul.Scalar %2900, %int2_3390 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3391 = torch.constant.int 1 - %2902 = torch.aten.add.Tensor %2901, %2898, %int1_3391 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_3392 = torch.constant.int 32 - %2903 = torch.aten.mul.Scalar %2902, %int32_3392 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3393 = torch.constant.int 1 - %2904 = torch.aten.add.Tensor %2903, %2890, %int1_3393 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_3394 = torch.constant.int 32 - %int2_3395 = torch.constant.int 2 + %3388 = torch.aten.add.Tensor %3387, %3371, %int1_3381 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_3382 = torch.constant.int 5 + %3389 = torch.prims.convert_element_type %3363, %int5_3382 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> 
!torch.vtensor<[4,1,8,128],f16> + %int32_3383 = torch.constant.int 32 + %int2_3384 = torch.constant.int 2 + %int8_3385 = torch.constant.int 8 + %int32_3386 = torch.constant.int 32 + %int128_3387 = torch.constant.int 128 + %3390 = torch.prim.ListConstruct %456, %int32_3383, %int2_3384, %int8_3385, %int32_3386, %int128_3387 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3391 = torch.aten.view %3211, %3390 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3391, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_3388 = torch.constant.int 128 + %3392 = torch.prim.ListConstruct %596, %int128_3388 : (!torch.int, !torch.int) -> !torch.list + %3393 = torch.aten.view %3391, %3392 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3393, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %3394 = torch.prim.ListConstruct %3388 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_3389 = torch.constant.bool false + %3395 = torch.aten.index_put %3393, %3394, %3389, %false_3389 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3395, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_3390 = torch.constant.int 32 + %int2_3391 = torch.constant.int 2 + %int8_3392 = torch.constant.int 8 + %int32_3393 = torch.constant.int 32 + %int128_3394 = torch.constant.int 128 + %3396 = torch.prim.ListConstruct %456, %int32_3390, %int2_3391, %int8_3392, %int32_3393, %int128_3394 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3397 = torch.aten.view %3395, %3396 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3397, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_3395 = torch.constant.int 2097152 + %3398 = torch.prim.ListConstruct %456, %int2097152_3395 : (!torch.int, !torch.int) -> !torch.list + %3399 = torch.aten.view %3397, %3398 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %3399, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> %int32_3396 = torch.constant.int 32 - %int8_3397 = torch.constant.int 8 - %int128_3398 = torch.constant.int 128 - %2905 = torch.prim.ListConstruct %437, %int32_3394, %int2_3395, %int32_3396, %int8_3397, %int128_3398 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2906 = torch.aten.view %2742, %2905 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2906, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> + %int2_3397 = torch.constant.int 2 + %int8_3398 = torch.constant.int 8 %int32_3399 = torch.constant.int 32 - %2907 = torch.aten.mul.int %437, %int32_3399 : !torch.int, !torch.int -> !torch.int - %int2_3400 = torch.constant.int 2 - %2908 = torch.aten.mul.int %2907, %int2_3400 : !torch.int, !torch.int -> !torch.int - %int32_3401 = torch.constant.int 32 - %2909 = torch.aten.mul.int %2908, %int32_3401 : !torch.int, !torch.int -> !torch.int - %int8_3402 = torch.constant.int 8 - %int128_3403 = 
torch.constant.int 128 - %2910 = torch.prim.ListConstruct %2909, %int8_3402, %int128_3403 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2911 = torch.aten.view %2906, %2910 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2911, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %2912 = torch.prim.ListConstruct %2904 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_3404 = torch.constant.bool false - %2913 = torch.aten.index_put %2911, %2912, %2885, %false_3404 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2913, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_3405 = torch.constant.int 32 - %int2_3406 = torch.constant.int 2 - %int32_3407 = torch.constant.int 32 - %int8_3408 = torch.constant.int 8 - %int128_3409 = torch.constant.int 128 - %2914 = torch.prim.ListConstruct %437, %int32_3405, %int2_3406, %int32_3407, %int8_3408, %int128_3409 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2915 = torch.aten.view %2913, %2914 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2915, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3410 = torch.constant.int 2097152 - %2916 = torch.prim.ListConstruct %437, %int2097152_3410 : (!torch.int, !torch.int) -> !torch.list - %2917 = torch.aten.view %2915, %2916 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2917, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_3411 = torch.constant.int 32 - %int2_3412 = torch.constant.int 2 + %int128_3400 = torch.constant.int 128 + %3400 = torch.prim.ListConstruct %456, %int32_3396, %int2_3397, %int8_3398, %int32_3399, %int128_3400 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3401 = torch.aten.view %3399, %3400 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3401, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_3401 = torch.constant.int 128 + %3402 = torch.prim.ListConstruct %596, %int128_3401 : (!torch.int, !torch.int) -> !torch.list + %3403 = torch.aten.view %3401, %3402 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3403, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_3402 = torch.constant.none + %3404 = torch.aten.clone %189, %none_3402 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3405 = torch.aten.detach %3404 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3406 = torch.aten.detach %3405 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3407 = torch.aten.detach %3406 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_3403 = torch.constant.int 1 + %int1_3404 = torch.constant.int 1 + %int1_3405 = torch.constant.int 1 + %3408 = torch.prim.ListConstruct %int1_3403, %int1_3404, %int1_3405 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3409 = torch.aten.view %3407, %3408 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_3406 = 
torch.constant.int 32 + %3410 = torch.aten.mul.Scalar %3368, %int32_3406 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int13_3407 = torch.constant.int 13 + %int1_3408 = torch.constant.int 1 + %3411 = torch.aten.add.Scalar %3410, %int13_3407, %int1_3408 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_3409 = torch.constant.int 2 + %3412 = torch.aten.mul.Scalar %3411, %int2_3409 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_3410 = torch.constant.int 1 + %3413 = torch.aten.add.Tensor %3412, %3409, %int1_3410 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_3411 = torch.constant.int 8 + %3414 = torch.aten.mul.Scalar %3413, %int8_3411 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_3412 = torch.constant.int 1 + %3415 = torch.aten.add.Tensor %3414, %3374, %int1_3412 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int32_3413 = torch.constant.int 32 - %int8_3414 = torch.constant.int 8 - %int128_3415 = torch.constant.int 128 - %2918 = torch.prim.ListConstruct %437, %int32_3411, %int2_3412, %int32_3413, %int8_3414, %int128_3415 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2919 = torch.aten.view %2917, %2918 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2919, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_3416 = torch.constant.int 8 - %int128_3417 = torch.constant.int 128 - %2920 = torch.prim.ListConstruct %2909, %int8_3416, %int128_3417 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %2921 = torch.aten.view %2919, %2920 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2921, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_3418 = torch.constant.int 32 - %2922 = torch.aten.floor_divide.Scalar %arg2, %int32_3418 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_3419 = torch.constant.int 1 - %2923 = torch.aten.unsqueeze %2922, %int1_3419 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3420 = torch.constant.int 1 - %false_3421 = torch.constant.bool false - %2924 = torch.aten.gather %arg3, %int1_3420, %2923, %false_3421 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_3422 = torch.constant.int 32 - %2925 = torch.aten.remainder.Scalar %arg2, %int32_3422 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_3423 = torch.constant.int 1 - %2926 = torch.aten.unsqueeze %2925, %int1_3423 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %3416 = torch.aten.mul.Scalar %3415, %int32_3413 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_3414 = torch.constant.int 1 + %3417 = torch.aten.add.Tensor %3416, %3371, %int1_3414 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_3415 = torch.constant.int 5 + %3418 = torch.prims.convert_element_type %3343, %int5_3415 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %3419 = torch.prim.ListConstruct %3417 : 
(!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_3416 = torch.constant.bool false + %3420 = torch.aten.index_put %3403, %3419, %3418, %false_3416 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3420, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_3417 = torch.constant.int 32 + %int2_3418 = torch.constant.int 2 + %int8_3419 = torch.constant.int 8 + %int32_3420 = torch.constant.int 32 + %int128_3421 = torch.constant.int 128 + %3421 = torch.prim.ListConstruct %456, %int32_3417, %int2_3418, %int8_3419, %int32_3420, %int128_3421 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3422 = torch.aten.view %3420, %3421 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3422, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_3422 = torch.constant.int 2097152 + %3423 = torch.prim.ListConstruct %456, %int2097152_3422 : (!torch.int, !torch.int) -> !torch.list + %3424 = torch.aten.view %3422, %3423 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %3424, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_3423 = torch.constant.none + %3425 = torch.aten.clone %190, %none_3423 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3426 = torch.aten.detach %3425 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3427 = torch.aten.detach %3426 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3428 = torch.aten.detach %3427 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %none_3424 = torch.constant.none - %2927 = torch.aten.clone %138, %none_3424 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_3425 = torch.constant.int 0 - %2928 = torch.aten.unsqueeze %2927, %int0_3425 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_3426 = torch.constant.int 4 - %int1_3427 = torch.constant.int 1 - %2929 = torch.prim.ListConstruct %int4_3426, %int1_3427 : (!torch.int, !torch.int) -> !torch.list - %int1_3428 = torch.constant.int 1 - %int1_3429 = torch.constant.int 1 - %2930 = torch.prim.ListConstruct %int1_3428, %int1_3429 : (!torch.int, !torch.int) -> !torch.list - %int4_3430 = torch.constant.int 4 - %int0_3431 = torch.constant.int 0 - %cpu_3432 = torch.constant.device "cpu" - %false_3433 = torch.constant.bool false - %2931 = torch.aten.empty_strided %2929, %2930, %int4_3430, %int0_3431, %cpu_3432, %false_3433 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int12_3434 = torch.constant.int 12 - %2932 = torch.aten.fill.Scalar %2931, %int12_3434 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_3435 = torch.constant.int 4 - %int1_3436 = torch.constant.int 1 - %2933 = torch.prim.ListConstruct %int4_3435, %int1_3436 : (!torch.int, !torch.int) -> !torch.list - %2934 = torch.aten.repeat %2928, %2933 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_3437 = torch.constant.int 32 - %2935 = torch.aten.mul.Scalar %2924, %int32_3437 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3438 = torch.constant.int 1 - %2936 = torch.aten.add.Tensor %2935, %2932, %int1_3438 : !torch.vtensor<[4,1],si64>, 
!torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_3439 = torch.constant.int 2 - %2937 = torch.aten.mul.Scalar %2936, %int2_3439 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3440 = torch.constant.int 1 - %2938 = torch.aten.add.Tensor %2937, %2934, %int1_3440 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_3441 = torch.constant.int 32 - %2939 = torch.aten.mul.Scalar %2938, %int32_3441 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3442 = torch.constant.int 1 - %2940 = torch.aten.add.Tensor %2939, %2926, %int1_3442 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %2941 = torch.prim.ListConstruct %2940 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_3443 = torch.constant.bool false - %2942 = torch.aten.index_put %2921, %2941, %2873, %false_3443 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %2942, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_3444 = torch.constant.int 32 - %int2_3445 = torch.constant.int 2 - %int32_3446 = torch.constant.int 32 - %int8_3447 = torch.constant.int 8 - %int128_3448 = torch.constant.int 128 - %2943 = torch.prim.ListConstruct %437, %int32_3444, %int2_3445, %int32_3446, %int8_3447, %int128_3448 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2944 = torch.aten.view %2942, %2943 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2944, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3449 = torch.constant.int 2097152 - %2945 = torch.prim.ListConstruct %437, %int2097152_3449 : (!torch.int, !torch.int) -> !torch.list - %2946 = torch.aten.view %2944, %2945 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %2946, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %3429 = torch.aten.clone %191, %none_3424 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3430 = torch.aten.detach %3429 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3431 = torch.aten.detach %3430 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3432 = torch.aten.detach %3431 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_3425 = torch.constant.none + %3433 = torch.aten.clone %192, %none_3425 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3434 = torch.aten.detach %3433 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3435 = torch.aten.detach %3434 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3436 = torch.aten.detach %3435 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_3426 = torch.constant.int 32 + %int2_3427 = torch.constant.int 2 + %int8_3428 = torch.constant.int 8 + %int32_3429 = torch.constant.int 32 + %int128_3430 = torch.constant.int 128 + %3437 = torch.prim.ListConstruct %456, %int32_3426, %int2_3427, %int8_3428, %int32_3429, %int128_3430 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3438 = torch.aten.view %3424, %3437 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3438, 
[%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %3439 = torch_c.to_builtin_tensor %3438 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %3440 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_3431 = tensor.cast %3440 : tensor<4x?xi64> to tensor + %3441 = torch_c.to_builtin_tensor %3428 : !torch.vtensor<[],si64> -> tensor + %3442 = torch_c.to_builtin_tensor %3432 : !torch.vtensor<[],si64> -> tensor + %3443 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%3439, %cast_3431, %3441, %3442) : (tensor, tensor, tensor, tensor) -> tensor + %cast_3432 = tensor.cast %3443 : tensor to tensor<4x?x8x32x128xf16> + %3444 = torch_c.from_builtin_tensor %cast_3432 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %3444, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %3445 = torch_c.to_builtin_tensor %3438 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %3446 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_3433 = tensor.cast %3446 : tensor<4x?xi64> to tensor + %3447 = torch_c.to_builtin_tensor %3428 : !torch.vtensor<[],si64> -> tensor + %3448 = torch_c.to_builtin_tensor %3436 : !torch.vtensor<[],si64> -> tensor + %3449 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%3445, %cast_3433, %3447, %3448) : (tensor, tensor, tensor, tensor) -> tensor + %cast_3434 = tensor.cast %3449 : tensor to tensor<4x?x8x32x128xf16> + %3450 = torch_c.from_builtin_tensor %cast_3434 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %3450, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_3435 = torch.constant.int 2 + %int3_3436 = torch.constant.int 3 + %3451 = torch.aten.transpose.int %3444, %int2_3435, %int3_3436 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3451, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_3437 = torch.constant.int 0 + %3452 = torch.aten.clone %3451, %int0_3437 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3452, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_3438 = torch.constant.int 4 + %int8_3439 = torch.constant.int 8 + %int128_3440 = torch.constant.int 128 + %3453 = torch.prim.ListConstruct %int4_3438, %457, %int8_3439, %int128_3440 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3454 = torch.aten._unsafe_view %3452, %3453 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3454, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_3441 = torch.constant.int 2 + %int3_3442 = torch.constant.int 3 + %3455 = torch.aten.transpose.int %3450, %int2_3441, %int3_3442 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape 
%3455, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_3443 = torch.constant.int 0 + %3456 = torch.aten.clone %3455, %int0_3443 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3456, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_3444 = torch.constant.int 4 + %int8_3445 = torch.constant.int 8 + %int128_3446 = torch.constant.int 128 + %3457 = torch.prim.ListConstruct %int4_3444, %457, %int8_3445, %int128_3446 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3458 = torch.aten._unsafe_view %3456, %3457 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3458, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_3447 = torch.constant.int -2 + %3459 = torch.aten.unsqueeze %3454, %int-2_3447 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %3459, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_3448 = torch.constant.int 4 + %int8_3449 = torch.constant.int 8 %int4_3450 = torch.constant.int 4 - %2947 = torch.prim.ListConstruct %int4_3450, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_3451 = torch.constant.int 1 - %2948 = torch.prim.ListConstruct %358, %int1_3451 : (!torch.int, !torch.int) -> !torch.list - %int4_3452 = torch.constant.int 4 + %int128_3451 = torch.constant.int 128 + %3460 = torch.prim.ListConstruct %int4_3448, %457, %int8_3449, %int4_3450, %int128_3451 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_3452 = torch.constant.bool false + %3461 = torch.aten.expand %3459, %3460, %false_3452 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3461, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int0_3453 = torch.constant.int 0 - %cpu_3454 = torch.constant.device "cpu" - %false_3455 = torch.constant.bool false - %2949 = torch.aten.empty_strided %2947, %2948, %int4_3452, %int0_3453, %cpu_3454, %false_3455 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2949, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int12_3456 = torch.constant.int 12 - %2950 = torch.aten.fill.Scalar %2949, %int12_3456 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2950, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_3457 = torch.constant.int 32 - %2951 = torch.aten.mul.Scalar %arg3, %int32_3457 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2951, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_3458 = torch.constant.int 1 - %2952 = torch.aten.add.Tensor %2951, %2950, %int1_3458 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %2952, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_3459 = torch.constant.int 4 - %2953 = torch.aten.mul.int %int4_3459, %358 : !torch.int, !torch.int -> !torch.int - %2954 = torch.prim.ListConstruct %2953 : (!torch.int) -> !torch.list - %2955 = torch.aten.view %2952, 
%2954 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %2955, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_3460 = torch.constant.int 32 - %int2_3461 = torch.constant.int 2 - %int32_3462 = torch.constant.int 32 - %int8_3463 = torch.constant.int 8 - %int128_3464 = torch.constant.int 128 - %2956 = torch.prim.ListConstruct %437, %int32_3460, %int2_3461, %int32_3462, %int8_3463, %int128_3464 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2957 = torch.aten.view %2946, %2956 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %2957, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> + %3462 = torch.aten.clone %3461, %int0_3453 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3462, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_3454 = torch.constant.int 4 + %int32_3455 = torch.constant.int 32 + %int128_3456 = torch.constant.int 128 + %3463 = torch.prim.ListConstruct %int4_3454, %457, %int32_3455, %int128_3456 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3464 = torch.aten._unsafe_view %3462, %3463 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3464, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_3457 = torch.constant.int -2 + %3465 = torch.aten.unsqueeze %3458, %int-2_3457 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %3465, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_3458 = torch.constant.int 4 + %int8_3459 = torch.constant.int 8 + %int4_3460 = torch.constant.int 4 + %int128_3461 = torch.constant.int 128 + %3466 = torch.prim.ListConstruct %int4_3458, %457, %int8_3459, %int4_3460, %int128_3461 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_3462 = torch.constant.bool false + %3467 = torch.aten.expand %3465, %3466, %false_3462 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3467, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_3463 = torch.constant.int 0 + %3468 = torch.aten.clone %3467, %int0_3463 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3468, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_3464 = torch.constant.int 4 %int32_3465 = torch.constant.int 32 - %2958 = torch.aten.mul.int %437, %int32_3465 : !torch.int, !torch.int -> !torch.int - %int2_3466 = torch.constant.int 2 - %int32_3467 = torch.constant.int 32 - %int8_3468 = torch.constant.int 8 - %int128_3469 = torch.constant.int 128 - %2959 = torch.prim.ListConstruct %2958, %int2_3466, %int32_3467, %int8_3468, %int128_3469 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2960 = torch.aten.view %2957, %2959 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %2960, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_3470 
= torch.constant.int 0 - %2961 = torch.aten.index_select %2960, %int0_3470, %2955 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %2961, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_3471 = torch.constant.int 4 + %int128_3466 = torch.constant.int 128 + %3469 = torch.prim.ListConstruct %int4_3464, %457, %int32_3465, %int128_3466 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3470 = torch.aten._unsafe_view %3468, %3469 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3470, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_3467 = torch.constant.int 1 + %int2_3468 = torch.constant.int 2 + %3471 = torch.aten.transpose.int %3353, %int1_3467, %int2_3468 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_3469 = torch.constant.int 1 + %int2_3470 = torch.constant.int 2 + %3472 = torch.aten.transpose.int %3464, %int1_3469, %int2_3470 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3472, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_3471 = torch.constant.int 1 %int2_3472 = torch.constant.int 2 - %int32_3473 = torch.constant.int 32 - %int8_3474 = torch.constant.int 8 - %int128_3475 = torch.constant.int 128 - %2962 = torch.prim.ListConstruct %int4_3471, %358, %int2_3472, %int32_3473, %int8_3474, %int128_3475 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2963 = torch.aten.view %2961, %2962 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2963, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_3476 = torch.constant.int 0 - %int0_3477 = torch.constant.int 0 - %int9223372036854775807_3478 = torch.constant.int 9223372036854775807 + %3473 = torch.aten.transpose.int %3470, %int1_3471, %int2_3472 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3473, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_3473 = torch.constant.float 0.000000e+00 + %false_3474 = torch.constant.bool false + %none_3475 = torch.constant.none + %3474:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3471, %3472, %3473, %float0.000000e00_3473, %false_3474, %470, %none_3475) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_3476 = torch.constant.int 1 + %int2_3477 = torch.constant.int 2 + %3475 = torch.aten.transpose.int %3474#0, %int1_3476, %int2_3477 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_3478 = torch.constant.int 4 %int1_3479 = torch.constant.int 1 - %2964 = torch.aten.slice.Tensor %2963, %int0_3476, %int0_3477, %int9223372036854775807_3478, %int1_3479 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2964, 
[%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_3480 = torch.constant.int 1 - %int0_3481 = torch.constant.int 0 - %int9223372036854775807_3482 = torch.constant.int 9223372036854775807 - %int1_3483 = torch.constant.int 1 - %2965 = torch.aten.slice.Tensor %2964, %int1_3480, %int0_3481, %int9223372036854775807_3482, %int1_3483 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2965, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_3484 = torch.constant.int 2 - %int0_3485 = torch.constant.int 0 - %2966 = torch.aten.select.int %2965, %int2_3484, %int0_3485 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2966, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_3486 = torch.constant.int 32 - %2967 = torch.aten.mul.int %358, %int32_3486 : !torch.int, !torch.int -> !torch.int - %int2_3487 = torch.constant.int 2 - %int0_3488 = torch.constant.int 0 + %int4096_3480 = torch.constant.int 4096 + %3476 = torch.prim.ListConstruct %int4_3478, %int1_3479, %int4096_3480 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3477 = torch.aten.view %3475, %3476 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_3481 = torch.constant.int -2 + %int-1_3482 = torch.constant.int -1 + %3478 = torch.aten.transpose.int %193, %int-2_3481, %int-1_3482 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_3483 = torch.constant.int 5 + %3479 = torch.prims.convert_element_type %3478, %int5_3483 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_3484 = torch.constant.int 4 + %int4096_3485 = torch.constant.int 4096 + %3480 = torch.prim.ListConstruct %int4_3484, %int4096_3485 : (!torch.int, !torch.int) -> !torch.list + %3481 = torch.aten.view %3477, %3480 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3482 = torch.aten.mm %3481, %3479 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_3486 = torch.constant.int 4 + %int1_3487 = torch.constant.int 1 + %int4096_3488 = torch.constant.int 4096 + %3483 = torch.prim.ListConstruct %int4_3486, %int1_3487, %int4096_3488 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3484 = torch.aten.view %3482, %3483 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_3489 = torch.constant.int 1 - %2968 = torch.aten.slice.Tensor %2966, %int2_3487, %int0_3488, %2967, %int1_3489 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2968, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_3490 = torch.constant.int 0 - %2969 = torch.aten.clone %2968, %int0_3490 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2969, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_3491 = torch.constant.int 1 - %2970 = torch.aten.size.int %2965, %int1_3491 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_3492 = torch.constant.int 32 - %2971 = torch.aten.mul.int %2970, 
%int32_3492 : !torch.int, !torch.int -> !torch.int - %int4_3493 = torch.constant.int 4 - %int8_3494 = torch.constant.int 8 - %int128_3495 = torch.constant.int 128 - %2972 = torch.prim.ListConstruct %int4_3493, %2971, %int8_3494, %int128_3495 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2973 = torch.aten._unsafe_view %2969, %2972 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2973, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_3496 = torch.constant.int 0 - %int0_3497 = torch.constant.int 0 - %int9223372036854775807_3498 = torch.constant.int 9223372036854775807 - %int1_3499 = torch.constant.int 1 - %2974 = torch.aten.slice.Tensor %2973, %int0_3496, %int0_3497, %int9223372036854775807_3498, %int1_3499 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2974, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_3500 = torch.constant.int 0 - %int0_3501 = torch.constant.int 0 - %int9223372036854775807_3502 = torch.constant.int 9223372036854775807 - %int1_3503 = torch.constant.int 1 - %2975 = torch.aten.slice.Tensor %2963, %int0_3500, %int0_3501, %int9223372036854775807_3502, %int1_3503 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2975, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_3504 = torch.constant.int 1 - %int0_3505 = torch.constant.int 0 - %int9223372036854775807_3506 = torch.constant.int 9223372036854775807 - %int1_3507 = torch.constant.int 1 - %2976 = torch.aten.slice.Tensor %2975, %int1_3504, %int0_3505, %int9223372036854775807_3506, %int1_3507 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %2976, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_3508 = torch.constant.int 2 - %int1_3509 = torch.constant.int 1 - %2977 = torch.aten.select.int %2976, %int2_3508, %int1_3509 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2977, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_3510 = torch.constant.int 2 - %int0_3511 = torch.constant.int 0 - %int1_3512 = torch.constant.int 1 - %2978 = torch.aten.slice.Tensor %2977, %int2_3510, %int0_3511, %2967, %int1_3512 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2978, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_3513 = torch.constant.int 0 - %2979 = torch.aten.clone %2978, %int0_3513 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %2979, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_3514 = torch.constant.int 1 - %2980 = torch.aten.size.int %2976, %int1_3514 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_3515 = torch.constant.int 32 - %2981 = torch.aten.mul.int %2980, %int32_3515 : !torch.int, !torch.int -> !torch.int - %int4_3516 = torch.constant.int 
4 - %int8_3517 = torch.constant.int 8 - %int128_3518 = torch.constant.int 128 - %2982 = torch.prim.ListConstruct %int4_3516, %2981, %int8_3517, %int128_3518 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2983 = torch.aten._unsafe_view %2979, %2982 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2983, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_3519 = torch.constant.int 0 - %int0_3520 = torch.constant.int 0 - %int9223372036854775807_3521 = torch.constant.int 9223372036854775807 - %int1_3522 = torch.constant.int 1 - %2984 = torch.aten.slice.Tensor %2983, %int0_3519, %int0_3520, %int9223372036854775807_3521, %int1_3522 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %2984, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_3523 = torch.constant.int -2 - %2985 = torch.aten.unsqueeze %2974, %int-2_3523 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2985, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_3524 = torch.constant.int 1 - %2986 = torch.aten.size.int %2973, %int1_3524 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_3525 = torch.constant.int 4 - %int8_3526 = torch.constant.int 8 - %int4_3527 = torch.constant.int 4 - %int128_3528 = torch.constant.int 128 - %2987 = torch.prim.ListConstruct %int4_3525, %2986, %int8_3526, %int4_3527, %int128_3528 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_3529 = torch.constant.bool false - %2988 = torch.aten.expand %2985, %2987, %false_3529 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2988, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_3530 = torch.constant.int 0 - %2989 = torch.aten.clone %2988, %int0_3530 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2989, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_3531 = torch.constant.int 4 - %int32_3532 = torch.constant.int 32 - %int128_3533 = torch.constant.int 128 - %2990 = torch.prim.ListConstruct %int4_3531, %2986, %int32_3532, %int128_3533 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2991 = torch.aten._unsafe_view %2989, %2990 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2991, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_3534 = torch.constant.int -2 - %2992 = torch.aten.unsqueeze %2984, %int-2_3534 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %2992, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_3535 = torch.constant.int 1 - %2993 = torch.aten.size.int %2983, %int1_3535 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int + %3485 = torch.aten.add.Tensor %3306, %3484, %int1_3489 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_3490 = torch.constant.int 6 + %3486 = 
torch.prims.convert_element_type %3485, %int6_3490 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_3491 = torch.constant.int 2 + %3487 = torch.aten.pow.Tensor_Scalar %3486, %int2_3491 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_3492 = torch.constant.int -1 + %3488 = torch.prim.ListConstruct %int-1_3492 : (!torch.int) -> !torch.list + %true_3493 = torch.constant.bool true + %none_3494 = torch.constant.none + %3489 = torch.aten.mean.dim %3487, %3488, %true_3493, %none_3494 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_3495 = torch.constant.float 9.9999997473787516E-6 + %int1_3496 = torch.constant.int 1 + %3490 = torch.aten.add.Scalar %3489, %float9.999990e-06_3495, %int1_3496 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %3491 = torch.aten.rsqrt %3490 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %3492 = torch.aten.mul.Tensor %3486, %3491 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_3497 = torch.constant.int 5 + %3493 = torch.prims.convert_element_type %3492, %int5_3497 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %3494 = torch.aten.mul.Tensor %194, %3493 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_3498 = torch.constant.int 5 + %3495 = torch.prims.convert_element_type %3494, %int5_3498 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_3499 = torch.constant.int -2 + %int-1_3500 = torch.constant.int -1 + %3496 = torch.aten.transpose.int %195, %int-2_3499, %int-1_3500 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_3501 = torch.constant.int 5 + %3497 = torch.prims.convert_element_type %3496, %int5_3501 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_3502 = torch.constant.int 4 + %int4096_3503 = torch.constant.int 4096 + %3498 = torch.prim.ListConstruct %int4_3502, %int4096_3503 : (!torch.int, !torch.int) -> !torch.list + %3499 = torch.aten.view %3495, %3498 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3500 = torch.aten.mm %3499, %3497 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_3504 = torch.constant.int 4 + %int1_3505 = torch.constant.int 1 + %int14336_3506 = torch.constant.int 14336 + %3501 = torch.prim.ListConstruct %int4_3504, %int1_3505, %int14336_3506 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3502 = torch.aten.view %3500, %3501 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %3503 = torch.aten.silu %3502 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_3507 = torch.constant.int -2 + %int-1_3508 = torch.constant.int -1 + %3504 = torch.aten.transpose.int %196, %int-2_3507, %int-1_3508 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_3509 = torch.constant.int 5 + %3505 = torch.prims.convert_element_type %3504, %int5_3509 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_3510 = torch.constant.int 4 + %int4096_3511 = torch.constant.int 4096 + %3506 = torch.prim.ListConstruct %int4_3510, %int4096_3511 : (!torch.int, !torch.int) 
-> !torch.list + %3507 = torch.aten.view %3495, %3506 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3508 = torch.aten.mm %3507, %3505 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_3512 = torch.constant.int 4 + %int1_3513 = torch.constant.int 1 + %int14336_3514 = torch.constant.int 14336 + %3509 = torch.prim.ListConstruct %int4_3512, %int1_3513, %int14336_3514 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3510 = torch.aten.view %3508, %3509 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %3511 = torch.aten.mul.Tensor %3503, %3510 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_3515 = torch.constant.int -2 + %int-1_3516 = torch.constant.int -1 + %3512 = torch.aten.transpose.int %197, %int-2_3515, %int-1_3516 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_3517 = torch.constant.int 5 + %3513 = torch.prims.convert_element_type %3512, %int5_3517 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_3518 = torch.constant.int 4 + %int14336_3519 = torch.constant.int 14336 + %3514 = torch.prim.ListConstruct %int4_3518, %int14336_3519 : (!torch.int, !torch.int) -> !torch.list + %3515 = torch.aten.view %3511, %3514 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %3516 = torch.aten.mm %3515, %3513 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_3520 = torch.constant.int 4 + %int1_3521 = torch.constant.int 1 + %int4096_3522 = torch.constant.int 4096 + %3517 = torch.prim.ListConstruct %int4_3520, %int1_3521, %int4096_3522 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3518 = torch.aten.view %3516, %3517 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_3523 = torch.constant.int 1 + %3519 = torch.aten.add.Tensor %3485, %3518, %int1_3523 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_3524 = torch.constant.int 6 + %3520 = torch.prims.convert_element_type %3519, %int6_3524 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_3525 = torch.constant.int 2 + %3521 = torch.aten.pow.Tensor_Scalar %3520, %int2_3525 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_3526 = torch.constant.int -1 + %3522 = torch.prim.ListConstruct %int-1_3526 : (!torch.int) -> !torch.list + %true_3527 = torch.constant.bool true + %none_3528 = torch.constant.none + %3523 = torch.aten.mean.dim %3521, %3522, %true_3527, %none_3528 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_3529 = torch.constant.float 9.9999997473787516E-6 + %int1_3530 = torch.constant.int 1 + %3524 = torch.aten.add.Scalar %3523, %float9.999990e-06_3529, %int1_3530 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %3525 = torch.aten.rsqrt %3524 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %3526 = torch.aten.mul.Tensor %3520, %3525 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_3531 = torch.constant.int 5 + %3527 = torch.prims.convert_element_type %3526, %int5_3531 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> 
!torch.vtensor<[4,1,4096],f16> + %3528 = torch.aten.mul.Tensor %198, %3527 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_3532 = torch.constant.int 5 + %3529 = torch.prims.convert_element_type %3528, %int5_3532 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_3533 = torch.constant.int -2 + %int-1_3534 = torch.constant.int -1 + %3530 = torch.aten.transpose.int %199, %int-2_3533, %int-1_3534 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_3535 = torch.constant.int 5 + %3531 = torch.prims.convert_element_type %3530, %int5_3535 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> %int4_3536 = torch.constant.int 4 - %int8_3537 = torch.constant.int 8 + %int4096_3537 = torch.constant.int 4096 + %3532 = torch.prim.ListConstruct %int4_3536, %int4096_3537 : (!torch.int, !torch.int) -> !torch.list + %3533 = torch.aten.view %3529, %3532 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3534 = torch.aten.mm %3533, %3531 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_3538 = torch.constant.int 4 - %int128_3539 = torch.constant.int 128 - %2994 = torch.prim.ListConstruct %int4_3536, %2993, %int8_3537, %int4_3538, %int128_3539 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_3540 = torch.constant.bool false - %2995 = torch.aten.expand %2992, %2994, %false_3540 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2995, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_3541 = torch.constant.int 0 - %2996 = torch.aten.clone %2995, %int0_3541 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %2996, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_3542 = torch.constant.int 4 - %int32_3543 = torch.constant.int 32 - %int128_3544 = torch.constant.int 128 - %2997 = torch.prim.ListConstruct %int4_3542, %2993, %int32_3543, %int128_3544 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %2998 = torch.aten._unsafe_view %2996, %2997 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %2998, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_3545 = torch.constant.int 1 - %int2_3546 = torch.constant.int 2 - %2999 = torch.aten.transpose.int %2879, %int1_3545, %int2_3546 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_3539 = torch.constant.int 1 + %int4096_3540 = torch.constant.int 4096 + %3535 = torch.prim.ListConstruct %int4_3538, %int1_3539, %int4096_3540 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3536 = torch.aten.view %3534, %3535 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_3541 = torch.constant.int -2 + %int-1_3542 = torch.constant.int -1 + %3537 = torch.aten.transpose.int %200, %int-2_3541, %int-1_3542 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_3543 = torch.constant.int 5 + %3538 = torch.prims.convert_element_type %3537, %int5_3543 : !torch.vtensor<[4096,1024],f16>, !torch.int -> 
!torch.vtensor<[4096,1024],f16> + %int4_3544 = torch.constant.int 4 + %int4096_3545 = torch.constant.int 4096 + %3539 = torch.prim.ListConstruct %int4_3544, %int4096_3545 : (!torch.int, !torch.int) -> !torch.list + %3540 = torch.aten.view %3529, %3539 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3541 = torch.aten.mm %3540, %3538 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_3546 = torch.constant.int 4 %int1_3547 = torch.constant.int 1 - %int2_3548 = torch.constant.int 2 - %3000 = torch.aten.transpose.int %2991, %int1_3547, %int2_3548 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3000, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_3549 = torch.constant.int 1 - %int2_3550 = torch.constant.int 2 - %3001 = torch.aten.transpose.int %2998, %int1_3549, %int2_3550 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3001, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_3551 = torch.constant.float 0.000000e+00 - %false_3552 = torch.constant.bool false - %none_3553 = torch.constant.none - %3002:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%2999, %3000, %3001, %float0.000000e00_3551, %false_3552, %368, %none_3553) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_3554 = torch.constant.int 1 - %int2_3555 = torch.constant.int 2 - %3003 = torch.aten.transpose.int %3002#0, %int1_3554, %int2_3555 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_3556 = torch.constant.int 4 - %int1_3557 = torch.constant.int 1 - %int4096_3558 = torch.constant.int 4096 - %3004 = torch.prim.ListConstruct %int4_3556, %int1_3557, %int4096_3558 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3005 = torch.aten.view %3003, %3004 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_3559 = torch.constant.int -2 - %int-1_3560 = torch.constant.int -1 - %3006 = torch.aten.transpose.int %139, %int-2_3559, %int-1_3560 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int1024_3548 = torch.constant.int 1024 + %3542 = torch.prim.ListConstruct %int4_3546, %int1_3547, %int1024_3548 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3543 = torch.aten.view %3541, %3542 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_3549 = torch.constant.int -2 + %int-1_3550 = torch.constant.int -1 + %3544 = torch.aten.transpose.int %201, %int-2_3549, %int-1_3550 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_3551 = torch.constant.int 5 + %3545 = torch.prims.convert_element_type %3544, %int5_3551 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_3552 = torch.constant.int 4 + %int4096_3553 = torch.constant.int 4096 + %3546 = torch.prim.ListConstruct %int4_3552, %int4096_3553 : (!torch.int, !torch.int) -> !torch.list + %3547 = torch.aten.view %3529, %3546 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> 
!torch.vtensor<[4,4096],f16> + %3548 = torch.aten.mm %3547, %3545 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_3554 = torch.constant.int 4 + %int1_3555 = torch.constant.int 1 + %int1024_3556 = torch.constant.int 1024 + %3549 = torch.prim.ListConstruct %int4_3554, %int1_3555, %int1024_3556 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3550 = torch.aten.view %3548, %3549 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_3557 = torch.constant.int 4 + %int1_3558 = torch.constant.int 1 + %int32_3559 = torch.constant.int 32 + %int128_3560 = torch.constant.int 128 + %3551 = torch.prim.ListConstruct %int4_3557, %int1_3558, %int32_3559, %int128_3560 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3552 = torch.aten.view %3536, %3551 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> %int4_3561 = torch.constant.int 4 - %int4096_3562 = torch.constant.int 4096 - %3007 = torch.prim.ListConstruct %int4_3561, %int4096_3562 : (!torch.int, !torch.int) -> !torch.list - %3008 = torch.aten.view %3005, %3007 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3009 = torch.aten.mm %3008, %3006 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_3563 = torch.constant.int 4 - %int1_3564 = torch.constant.int 1 - %int4096_3565 = torch.constant.int 4096 - %3010 = torch.prim.ListConstruct %int4_3563, %int1_3564, %int4096_3565 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3011 = torch.aten.view %3009, %3010 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_3562 = torch.constant.int 1 + %int8_3563 = torch.constant.int 8 + %int128_3564 = torch.constant.int 128 + %3553 = torch.prim.ListConstruct %int4_3561, %int1_3562, %int8_3563, %int128_3564 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3554 = torch.aten.view %3543, %3553 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_3565 = torch.constant.int 4 %int1_3566 = torch.constant.int 1 - %3012 = torch.aten.add.Tensor %2839, %3011, %int1_3566 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_3567 = torch.constant.int 6 - %3013 = torch.prims.convert_element_type %3012, %int6_3567 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_3568 = torch.constant.int 2 - %3014 = torch.aten.pow.Tensor_Scalar %3013, %int2_3568 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_3569 = torch.constant.int -1 - %3015 = torch.prim.ListConstruct %int-1_3569 : (!torch.int) -> !torch.list - %true_3570 = torch.constant.bool true - %none_3571 = torch.constant.none - %3016 = torch.aten.mean.dim %3014, %3015, %true_3570, %none_3571 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_3572 = torch.constant.float 9.9999997473787516E-6 - %int1_3573 = torch.constant.int 1 - %3017 = torch.aten.add.Scalar %3016, %float9.999990e-06_3572, %int1_3573 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %3018 = torch.aten.rsqrt %3017 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %3019 = torch.aten.mul.Tensor %3013, %3018 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> 
!torch.vtensor<[4,1,4096],f32> - %int5_3574 = torch.constant.int 5 - %3020 = torch.prims.convert_element_type %3019, %int5_3574 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %3021 = torch.aten.mul.Tensor %140, %3020 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_3575 = torch.constant.int 5 - %3022 = torch.prims.convert_element_type %3021, %int5_3575 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_3576 = torch.constant.int -2 - %int-1_3577 = torch.constant.int -1 - %3023 = torch.aten.transpose.int %141, %int-2_3576, %int-1_3577 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3578 = torch.constant.int 4 - %int4096_3579 = torch.constant.int 4096 - %3024 = torch.prim.ListConstruct %int4_3578, %int4096_3579 : (!torch.int, !torch.int) -> !torch.list - %3025 = torch.aten.view %3022, %3024 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3026 = torch.aten.mm %3025, %3023 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_3580 = torch.constant.int 4 + %int8_3567 = torch.constant.int 8 + %int128_3568 = torch.constant.int 128 + %3555 = torch.prim.ListConstruct %int4_3565, %int1_3566, %int8_3567, %int128_3568 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3556 = torch.aten.view %3550, %3555 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_3569 = torch.constant.int 1 + %int2_3570 = torch.constant.int 2 + %3557 = torch.aten.transpose.int %3552, %int1_3569, %int2_3570 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %3558 = torch.aten.mul.Tensor %3557, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_3571 = torch.constant.int 3 + %int0_3572 = torch.constant.int 0 + %int64_3573 = torch.constant.int 64 + %int1_3574 = torch.constant.int 1 + %3559 = torch.aten.slice.Tensor %3557, %int3_3571, %int0_3572, %int64_3573, %int1_3574 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_3575 = torch.constant.int 3 + %int64_3576 = torch.constant.int 64 + %int9223372036854775807_3577 = torch.constant.int 9223372036854775807 + %int1_3578 = torch.constant.int 1 + %3560 = torch.aten.slice.Tensor %3557, %int3_3575, %int64_3576, %int9223372036854775807_3577, %int1_3578 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %3561 = torch.aten.neg %3560 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %3562 = torch.prim.ListConstruct %3561, %3559 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_3579 = torch.constant.int -1 + %3563 = torch.aten.cat %3562, %int-1_3579 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %3564 = torch.aten.mul.Tensor %3563, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_3580 = torch.constant.int 1 + %3565 = torch.aten.add.Tensor %3558, %3564, %int1_3580 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_3581 = torch.constant.int 1 - %int14336_3582 = torch.constant.int 14336 - %3027 = 
torch.prim.ListConstruct %int4_3580, %int1_3581, %int14336_3582 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3028 = torch.aten.view %3026, %3027 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %3029 = torch.aten.silu %3028 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_3583 = torch.constant.int -2 - %int-1_3584 = torch.constant.int -1 - %3030 = torch.aten.transpose.int %142, %int-2_3583, %int-1_3584 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3585 = torch.constant.int 4 - %int4096_3586 = torch.constant.int 4096 - %3031 = torch.prim.ListConstruct %int4_3585, %int4096_3586 : (!torch.int, !torch.int) -> !torch.list - %3032 = torch.aten.view %3022, %3031 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3033 = torch.aten.mm %3032, %3030 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_3587 = torch.constant.int 4 + %int2_3582 = torch.constant.int 2 + %3566 = torch.aten.transpose.int %3565, %int1_3581, %int2_3582 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_3583 = torch.constant.int 1 + %int2_3584 = torch.constant.int 2 + %3567 = torch.aten.transpose.int %3554, %int1_3583, %int2_3584 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %3568 = torch.aten.mul.Tensor %3567, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_3585 = torch.constant.int 3 + %int0_3586 = torch.constant.int 0 + %int64_3587 = torch.constant.int 64 %int1_3588 = torch.constant.int 1 - %int14336_3589 = torch.constant.int 14336 - %3034 = torch.prim.ListConstruct %int4_3587, %int1_3588, %int14336_3589 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3035 = torch.aten.view %3033, %3034 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %3036 = torch.aten.mul.Tensor %3029, %3035 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_3590 = torch.constant.int -2 - %int-1_3591 = torch.constant.int -1 - %3037 = torch.aten.transpose.int %143, %int-2_3590, %int-1_3591 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_3592 = torch.constant.int 4 - %int14336_3593 = torch.constant.int 14336 - %3038 = torch.prim.ListConstruct %int4_3592, %int14336_3593 : (!torch.int, !torch.int) -> !torch.list - %3039 = torch.aten.view %3036, %3038 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %3040 = torch.aten.mm %3039, %3037 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_3594 = torch.constant.int 4 + %3569 = torch.aten.slice.Tensor %3567, %int3_3585, %int0_3586, %int64_3587, %int1_3588 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_3589 = torch.constant.int 3 + %int64_3590 = torch.constant.int 64 + %int9223372036854775807_3591 = torch.constant.int 9223372036854775807 + %int1_3592 = torch.constant.int 1 + %3570 = torch.aten.slice.Tensor %3567, %int3_3589, %int64_3590, %int9223372036854775807_3591, %int1_3592 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %3571 = torch.aten.neg %3570 : 
!torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %3572 = torch.prim.ListConstruct %3571, %3569 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_3593 = torch.constant.int -1 + %3573 = torch.aten.cat %3572, %int-1_3593 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %3574 = torch.aten.mul.Tensor %3573, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_3594 = torch.constant.int 1 + %3575 = torch.aten.add.Tensor %3568, %3574, %int1_3594 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> %int1_3595 = torch.constant.int 1 - %int4096_3596 = torch.constant.int 4096 - %3041 = torch.prim.ListConstruct %int4_3594, %int1_3595, %int4096_3596 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3042 = torch.aten.view %3040, %3041 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_3597 = torch.constant.int 1 - %3043 = torch.aten.add.Tensor %3012, %3042, %int1_3597 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_3598 = torch.constant.int 6 - %3044 = torch.prims.convert_element_type %3043, %int6_3598 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_3599 = torch.constant.int 2 - %3045 = torch.aten.pow.Tensor_Scalar %3044, %int2_3599 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_3600 = torch.constant.int -1 - %3046 = torch.prim.ListConstruct %int-1_3600 : (!torch.int) -> !torch.list - %true_3601 = torch.constant.bool true - %none_3602 = torch.constant.none - %3047 = torch.aten.mean.dim %3045, %3046, %true_3601, %none_3602 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_3603 = torch.constant.float 9.9999997473787516E-6 - %int1_3604 = torch.constant.int 1 - %3048 = torch.aten.add.Scalar %3047, %float9.999990e-06_3603, %int1_3604 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %3049 = torch.aten.rsqrt %3048 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %3050 = torch.aten.mul.Tensor %3044, %3049 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_3605 = torch.constant.int 5 - %3051 = torch.prims.convert_element_type %3050, %int5_3605 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %3052 = torch.aten.mul.Tensor %144, %3051 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_3606 = torch.constant.int 5 - %3053 = torch.prims.convert_element_type %3052, %int5_3606 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_3607 = torch.constant.int -2 - %int-1_3608 = torch.constant.int -1 - %3054 = torch.aten.transpose.int %145, %int-2_3607, %int-1_3608 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3609 = torch.constant.int 4 - %int4096_3610 = torch.constant.int 4096 - %3055 = torch.prim.ListConstruct %int4_3609, %int4096_3610 : (!torch.int, !torch.int) -> !torch.list - %3056 = torch.aten.view %3053, %3055 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3057 = torch.aten.mm %3056, %3054 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> 
!torch.vtensor<[4,4096],f16> - %int4_3611 = torch.constant.int 4 - %int1_3612 = torch.constant.int 1 - %int4096_3613 = torch.constant.int 4096 - %3058 = torch.prim.ListConstruct %int4_3611, %int1_3612, %int4096_3613 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3059 = torch.aten.view %3057, %3058 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_3614 = torch.constant.int -2 - %int-1_3615 = torch.constant.int -1 - %3060 = torch.aten.transpose.int %146, %int-2_3614, %int-1_3615 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3616 = torch.constant.int 4 - %int4096_3617 = torch.constant.int 4096 - %3061 = torch.prim.ListConstruct %int4_3616, %int4096_3617 : (!torch.int, !torch.int) -> !torch.list - %3062 = torch.aten.view %3053, %3061 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3063 = torch.aten.mm %3062, %3060 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_3618 = torch.constant.int 4 + %int2_3596 = torch.constant.int 2 + %3576 = torch.aten.transpose.int %3575, %int1_3595, %int2_3596 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_3597 = torch.constant.int 32 + %3577 = torch.aten.floor_divide.Scalar %arg2, %int32_3597 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_3598 = torch.constant.int 1 + %3578 = torch.aten.unsqueeze %3577, %int1_3598 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_3599 = torch.constant.int 1 + %false_3600 = torch.constant.bool false + %3579 = torch.aten.gather %arg3, %int1_3599, %3578, %false_3600 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_3601 = torch.constant.int 4 + %int1_3602 = torch.constant.int 1 + %int1_3603 = torch.constant.int 1 + %3580 = torch.prim.ListConstruct %int4_3601, %int1_3602, %int1_3603 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3581 = torch.aten.view %3579, %3580 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_3604 = torch.constant.int 32 + %3582 = torch.aten.remainder.Scalar %arg2, %int32_3604 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_3605 = torch.constant.int 4 + %int1_3606 = torch.constant.int 1 + %int1_3607 = torch.constant.int 1 + %3583 = torch.prim.ListConstruct %int4_3605, %int1_3606, %int1_3607 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3584 = torch.aten.view %3582, %3583 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_3608 = torch.constant.int 8 + %none_3609 = torch.constant.none + %none_3610 = torch.constant.none + %cpu_3611 = torch.constant.device "cpu" + %false_3612 = torch.constant.bool false + %3585 = torch.aten.arange %int8_3608, %none_3609, %none_3610, %cpu_3611, %false_3612 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_3613 = torch.constant.int 1 + %int1_3614 = torch.constant.int 1 + %int8_3615 = torch.constant.int 8 + %3586 = torch.prim.ListConstruct %int1_3613, %int1_3614, %int8_3615 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3587 = torch.aten.view %3585, %3586 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_3616 = torch.constant.none + %3588 = torch.aten.clone %202, %none_3616 : !torch.vtensor<[],si64>, !torch.none -> 
!torch.vtensor<[],si64> + %3589 = torch.aten.detach %3588 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3590 = torch.aten.detach %3589 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3591 = torch.aten.detach %3590 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_3617 = torch.constant.int 1 + %int1_3618 = torch.constant.int 1 %int1_3619 = torch.constant.int 1 - %int1024_3620 = torch.constant.int 1024 - %3064 = torch.prim.ListConstruct %int4_3618, %int1_3619, %int1024_3620 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3065 = torch.aten.view %3063, %3064 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_3621 = torch.constant.int -2 - %int-1_3622 = torch.constant.int -1 - %3066 = torch.aten.transpose.int %147, %int-2_3621, %int-1_3622 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3623 = torch.constant.int 4 - %int4096_3624 = torch.constant.int 4096 - %3067 = torch.prim.ListConstruct %int4_3623, %int4096_3624 : (!torch.int, !torch.int) -> !torch.list - %3068 = torch.aten.view %3053, %3067 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3069 = torch.aten.mm %3068, %3066 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_3625 = torch.constant.int 4 - %int1_3626 = torch.constant.int 1 - %int1024_3627 = torch.constant.int 1024 - %3070 = torch.prim.ListConstruct %int4_3625, %int1_3626, %int1024_3627 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3071 = torch.aten.view %3069, %3070 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_3628 = torch.constant.int 4 - %int1_3629 = torch.constant.int 1 - %int32_3630 = torch.constant.int 32 - %int128_3631 = torch.constant.int 128 - %3072 = torch.prim.ListConstruct %int4_3628, %int1_3629, %int32_3630, %int128_3631 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3073 = torch.aten.view %3059, %3072 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_3632 = torch.constant.int 4 - %int1_3633 = torch.constant.int 1 - %int8_3634 = torch.constant.int 8 - %int128_3635 = torch.constant.int 128 - %3074 = torch.prim.ListConstruct %int4_3632, %int1_3633, %int8_3634, %int128_3635 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3075 = torch.aten.view %3065, %3074 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_3636 = torch.constant.int 4 - %int1_3637 = torch.constant.int 1 + %3592 = torch.prim.ListConstruct %int1_3617, %int1_3618, %int1_3619 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3593 = torch.aten.view %3591, %3592 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_3620 = torch.constant.int 32 + %3594 = torch.aten.mul.Scalar %3581, %int32_3620 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int14 = torch.constant.int 14 + %int1_3621 = torch.constant.int 1 + %3595 = torch.aten.add.Scalar %3594, %int14, %int1_3621 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_3622 = torch.constant.int 2 + %3596 = torch.aten.mul.Scalar %3595, %int2_3622 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_3623 = torch.constant.int 1 + %3597 = torch.aten.add.Tensor %3596, %3593, %int1_3623 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> 
!torch.vtensor<[4,1,1],si64> + %int8_3624 = torch.constant.int 8 + %3598 = torch.aten.mul.Scalar %3597, %int8_3624 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_3625 = torch.constant.int 1 + %3599 = torch.aten.add.Tensor %3598, %3587, %int1_3625 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_3626 = torch.constant.int 32 + %3600 = torch.aten.mul.Scalar %3599, %int32_3626 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_3627 = torch.constant.int 1 + %3601 = torch.aten.add.Tensor %3600, %3584, %int1_3627 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_3628 = torch.constant.int 5 + %3602 = torch.prims.convert_element_type %3576, %int5_3628 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_3629 = torch.constant.int 32 + %int2_3630 = torch.constant.int 2 + %int8_3631 = torch.constant.int 8 + %int32_3632 = torch.constant.int 32 + %int128_3633 = torch.constant.int 128 + %3603 = torch.prim.ListConstruct %456, %int32_3629, %int2_3630, %int8_3631, %int32_3632, %int128_3633 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3604 = torch.aten.view %3424, %3603 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3604, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_3634 = torch.constant.int 128 + %3605 = torch.prim.ListConstruct %596, %int128_3634 : (!torch.int, !torch.int) -> !torch.list + %3606 = torch.aten.view %3604, %3605 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3606, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %3607 = torch.prim.ListConstruct %3601 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_3635 = torch.constant.bool false + %3608 = torch.aten.index_put %3606, %3607, %3602, %false_3635 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3608, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_3636 = torch.constant.int 32 + %int2_3637 = torch.constant.int 2 %int8_3638 = torch.constant.int 8 - %int128_3639 = torch.constant.int 128 - %3076 = torch.prim.ListConstruct %int4_3636, %int1_3637, %int8_3638, %int128_3639 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3077 = torch.aten.view %3071, %3076 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_3640 = torch.constant.int 6 - %3078 = torch.prims.convert_element_type %3073, %int6_3640 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %3079 = torch_c.to_builtin_tensor %3078 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %3080 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %3081 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%3079, %3080) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %3082 = torch_c.from_builtin_tensor %3081 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_3641 = torch.constant.int 5 - %3083 = torch.prims.convert_element_type %3082, %int5_3641 : !torch.vtensor<[4,1,32,128],f32>, 
!torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_3642 = torch.constant.int 6 - %3084 = torch.prims.convert_element_type %3075, %int6_3642 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %3085 = torch_c.to_builtin_tensor %3084 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %3086 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %3087 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%3085, %3086) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %3088 = torch_c.from_builtin_tensor %3087 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_3643 = torch.constant.int 5 - %3089 = torch.prims.convert_element_type %3088, %int5_3643 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_3644 = torch.constant.int 32 - %3090 = torch.aten.floor_divide.Scalar %arg2, %int32_3644 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_3645 = torch.constant.int 1 - %3091 = torch.aten.unsqueeze %3090, %int1_3645 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3646 = torch.constant.int 1 - %false_3647 = torch.constant.bool false - %3092 = torch.aten.gather %arg3, %int1_3646, %3091, %false_3647 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_3648 = torch.constant.int 32 - %3093 = torch.aten.remainder.Scalar %arg2, %int32_3648 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int32_3639 = torch.constant.int 32 + %int128_3640 = torch.constant.int 128 + %3609 = torch.prim.ListConstruct %456, %int32_3636, %int2_3637, %int8_3638, %int32_3639, %int128_3640 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3610 = torch.aten.view %3608, %3609 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3610, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_3641 = torch.constant.int 2097152 + %3611 = torch.prim.ListConstruct %456, %int2097152_3641 : (!torch.int, !torch.int) -> !torch.list + %3612 = torch.aten.view %3610, %3611 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %3612, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_3642 = torch.constant.int 32 + %int2_3643 = torch.constant.int 2 + %int8_3644 = torch.constant.int 8 + %int32_3645 = torch.constant.int 32 + %int128_3646 = torch.constant.int 128 + %3613 = torch.prim.ListConstruct %456, %int32_3642, %int2_3643, %int8_3644, %int32_3645, %int128_3646 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3614 = torch.aten.view %3612, %3613 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3614, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_3647 = torch.constant.int 128 + %3615 = torch.prim.ListConstruct %596, %int128_3647 : (!torch.int, !torch.int) -> !torch.list + %3616 = torch.aten.view %3614, %3615 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3616, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_3648 = torch.constant.none + %3617 = 
torch.aten.clone %203, %none_3648 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3618 = torch.aten.detach %3617 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3619 = torch.aten.detach %3618 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3620 = torch.aten.detach %3619 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %int1_3649 = torch.constant.int 1 - %3094 = torch.aten.unsqueeze %3093, %int1_3649 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_3650 = torch.constant.none - %3095 = torch.aten.clone %148, %none_3650 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_3651 = torch.constant.int 0 - %3096 = torch.aten.unsqueeze %3095, %int0_3651 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_3652 = torch.constant.int 4 - %int1_3653 = torch.constant.int 1 - %3097 = torch.prim.ListConstruct %int4_3652, %int1_3653 : (!torch.int, !torch.int) -> !torch.list + %int1_3650 = torch.constant.int 1 + %int1_3651 = torch.constant.int 1 + %3621 = torch.prim.ListConstruct %int1_3649, %int1_3650, %int1_3651 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3622 = torch.aten.view %3620, %3621 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_3652 = torch.constant.int 32 + %3623 = torch.aten.mul.Scalar %3581, %int32_3652 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int14_3653 = torch.constant.int 14 %int1_3654 = torch.constant.int 1 - %int1_3655 = torch.constant.int 1 - %3098 = torch.prim.ListConstruct %int1_3654, %int1_3655 : (!torch.int, !torch.int) -> !torch.list - %int4_3656 = torch.constant.int 4 - %int0_3657 = torch.constant.int 0 - %cpu_3658 = torch.constant.device "cpu" - %false_3659 = torch.constant.bool false - %3099 = torch.aten.empty_strided %3097, %3098, %int4_3656, %int0_3657, %cpu_3658, %false_3659 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int13 = torch.constant.int 13 - %3100 = torch.aten.fill.Scalar %3099, %int13 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_3660 = torch.constant.int 4 - %int1_3661 = torch.constant.int 1 - %3101 = torch.prim.ListConstruct %int4_3660, %int1_3661 : (!torch.int, !torch.int) -> !torch.list - %3102 = torch.aten.repeat %3096, %3101 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_3662 = torch.constant.int 32 - %3103 = torch.aten.mul.Scalar %3092, %int32_3662 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3663 = torch.constant.int 1 - %3104 = torch.aten.add.Tensor %3103, %3100, %int1_3663 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %3624 = torch.aten.add.Scalar %3623, %int14_3653, %int1_3654 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_3655 = torch.constant.int 2 + %3625 = torch.aten.mul.Scalar %3624, %int2_3655 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_3656 = torch.constant.int 1 + %3626 = torch.aten.add.Tensor %3625, %3622, %int1_3656 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_3657 = torch.constant.int 8 + %3627 = torch.aten.mul.Scalar %3626, %int8_3657 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_3658 = torch.constant.int 1 + %3628 = torch.aten.add.Tensor %3627, 
%3587, %int1_3658 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_3659 = torch.constant.int 32 + %3629 = torch.aten.mul.Scalar %3628, %int32_3659 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_3660 = torch.constant.int 1 + %3630 = torch.aten.add.Tensor %3629, %3584, %int1_3660 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_3661 = torch.constant.int 5 + %3631 = torch.prims.convert_element_type %3556, %int5_3661 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %3632 = torch.prim.ListConstruct %3630 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_3662 = torch.constant.bool false + %3633 = torch.aten.index_put %3616, %3632, %3631, %false_3662 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3633, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_3663 = torch.constant.int 32 %int2_3664 = torch.constant.int 2 - %3105 = torch.aten.mul.Scalar %3104, %int2_3664 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3665 = torch.constant.int 1 - %3106 = torch.aten.add.Tensor %3105, %3102, %int1_3665 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int8_3665 = torch.constant.int 8 %int32_3666 = torch.constant.int 32 - %3107 = torch.aten.mul.Scalar %3106, %int32_3666 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3667 = torch.constant.int 1 - %3108 = torch.aten.add.Tensor %3107, %3094, %int1_3667 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_3668 = torch.constant.int 32 - %int2_3669 = torch.constant.int 2 - %int32_3670 = torch.constant.int 32 - %int8_3671 = torch.constant.int 8 - %int128_3672 = torch.constant.int 128 - %3109 = torch.prim.ListConstruct %437, %int32_3668, %int2_3669, %int32_3670, %int8_3671, %int128_3672 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3110 = torch.aten.view %2946, %3109 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3110, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_3673 = torch.constant.int 32 - %3111 = torch.aten.mul.int %437, %int32_3673 : !torch.int, !torch.int -> !torch.int - %int2_3674 = torch.constant.int 2 - %3112 = torch.aten.mul.int %3111, %int2_3674 : !torch.int, !torch.int -> !torch.int + %int128_3667 = torch.constant.int 128 + %3634 = torch.prim.ListConstruct %456, %int32_3663, %int2_3664, %int8_3665, %int32_3666, %int128_3667 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3635 = torch.aten.view %3633, %3634 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3635, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_3668 = torch.constant.int 2097152 + %3636 = torch.prim.ListConstruct %456, %int2097152_3668 : (!torch.int, !torch.int) -> !torch.list + %3637 = torch.aten.view %3635, %3636 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %3637, [%454], 
affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_3669 = torch.constant.none + %3638 = torch.aten.clone %204, %none_3669 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3639 = torch.aten.detach %3638 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3640 = torch.aten.detach %3639 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3641 = torch.aten.detach %3640 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_3670 = torch.constant.none + %3642 = torch.aten.clone %205, %none_3670 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3643 = torch.aten.detach %3642 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3644 = torch.aten.detach %3643 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3645 = torch.aten.detach %3644 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_3671 = torch.constant.none + %3646 = torch.aten.clone %206, %none_3671 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3647 = torch.aten.detach %3646 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3648 = torch.aten.detach %3647 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3649 = torch.aten.detach %3648 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_3672 = torch.constant.int 32 + %int2_3673 = torch.constant.int 2 + %int8_3674 = torch.constant.int 8 %int32_3675 = torch.constant.int 32 - %3113 = torch.aten.mul.int %3112, %int32_3675 : !torch.int, !torch.int -> !torch.int - %int8_3676 = torch.constant.int 8 - %int128_3677 = torch.constant.int 128 - %3114 = torch.prim.ListConstruct %3113, %int8_3676, %int128_3677 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3115 = torch.aten.view %3110, %3114 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %3115, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %3116 = torch.prim.ListConstruct %3108 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_3678 = torch.constant.bool false - %3117 = torch.aten.index_put %3115, %3116, %3089, %false_3678 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %3117, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_3679 = torch.constant.int 32 - %int2_3680 = torch.constant.int 2 - %int32_3681 = torch.constant.int 32 - %int8_3682 = torch.constant.int 8 - %int128_3683 = torch.constant.int 128 - %3118 = torch.prim.ListConstruct %437, %int32_3679, %int2_3680, %int32_3681, %int8_3682, %int128_3683 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3119 = torch.aten.view %3117, %3118 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3119, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3684 = torch.constant.int 2097152 - %3120 = torch.prim.ListConstruct %437, %int2097152_3684 : (!torch.int, !torch.int) -> !torch.list - %3121 = torch.aten.view %3119, %3120 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3121, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_3685 = torch.constant.int 32 - %int2_3686 = torch.constant.int 2 - %int32_3687 = torch.constant.int 32 - %int8_3688 = 
torch.constant.int 8
-    %int128_3689 = torch.constant.int 128
-    %3122 = torch.prim.ListConstruct %437, %int32_3685, %int2_3686, %int32_3687, %int8_3688, %int128_3689 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3123 = torch.aten.view %3121, %3122 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3123, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
-    %int8_3690 = torch.constant.int 8
-    %int128_3691 = torch.constant.int 128
-    %3124 = torch.prim.ListConstruct %3113, %int8_3690, %int128_3691 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3125 = torch.aten.view %3123, %3124 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,128],f16>
-    torch.bind_symbolic_shape %3125, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
-    %int32_3692 = torch.constant.int 32
-    %3126 = torch.aten.floor_divide.Scalar %arg2, %int32_3692 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
-    %int1_3693 = torch.constant.int 1
-    %3127 = torch.aten.unsqueeze %3126, %int1_3693 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_3694 = torch.constant.int 1
-    %false_3695 = torch.constant.bool false
-    %3128 = torch.aten.gather %arg3, %int1_3694, %3127, %false_3695 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64>
-    %int32_3696 = torch.constant.int 32
-    %3129 = torch.aten.remainder.Scalar %arg2, %int32_3696 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
-    %int1_3697 = torch.constant.int 1
-    %3130 = torch.aten.unsqueeze %3129, %int1_3697 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %none_3698 = torch.constant.none
-    %3131 = torch.aten.clone %149, %none_3698 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+    %int128_3676 = torch.constant.int 128
+    %3650 = torch.prim.ListConstruct %456, %int32_3672, %int2_3673, %int8_3674, %int32_3675, %int128_3676 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3651 = torch.aten.view %3637, %3650 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %3651, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %3652 = torch_c.to_builtin_tensor %3651 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+    %3653 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+    %cast_3677 = tensor.cast %3653 : tensor<4x?xi64> to tensor<?x?xi64>
+    %3654 = torch_c.to_builtin_tensor %3641 : !torch.vtensor<[],si64> -> tensor<i64>
+    %3655 = torch_c.to_builtin_tensor %3645 : !torch.vtensor<[],si64> -> tensor<i64>
+    %3656 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%3652, %cast_3677, %3654, %3655) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+    %cast_3678 = tensor.cast %3656 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+    %3657 = torch_c.from_builtin_tensor %cast_3678 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+    torch.bind_symbolic_shape %3657, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+    %3658 = torch_c.to_builtin_tensor %3651 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+    %3659 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+    %cast_3679 = tensor.cast %3659 : tensor<4x?xi64> to tensor<?x?xi64>
+    %3660 = torch_c.to_builtin_tensor %3641 : !torch.vtensor<[],si64> -> tensor<i64>
+    %3661 = torch_c.to_builtin_tensor %3649 : !torch.vtensor<[],si64> -> tensor<i64>
+    %3662 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%3658, %cast_3679, %3660, %3661) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+    %cast_3680 = tensor.cast %3662 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+    %3663 = torch_c.from_builtin_tensor %cast_3680 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+    torch.bind_symbolic_shape %3663, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+    %int2_3681 = torch.constant.int 2
+    %int3_3682 = torch.constant.int 3
+    %3664 = torch.aten.transpose.int %3657, %int2_3681, %int3_3682 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %3664, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+    %int0_3683 = torch.constant.int 0
+    %3665 = torch.aten.clone %3664, %int0_3683 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %3665, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+    %int4_3684 = torch.constant.int 4
+    %int8_3685 = torch.constant.int 8
+    %int128_3686 = torch.constant.int 128
+    %3666 = torch.prim.ListConstruct %int4_3684, %457, %int8_3685, %int128_3686 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3667 = torch.aten._unsafe_view %3665, %3666 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+    torch.bind_symbolic_shape %3667, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+    %int2_3687 = torch.constant.int 2
+    %int3_3688 = torch.constant.int 3
+    %3668 = torch.aten.transpose.int %3663, %int2_3687, %int3_3688 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %3668, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+    %int0_3689 = torch.constant.int 0
+    %3669 = torch.aten.clone %3668, %int0_3689 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %3669, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+    %int4_3690 = torch.constant.int 4
+    %int8_3691 = torch.constant.int 8
+    %int128_3692 = torch.constant.int 128
+    %3670 = torch.prim.ListConstruct %int4_3690, %457, %int8_3691, %int128_3692 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3671 = torch.aten._unsafe_view %3669, %3670 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+    torch.bind_symbolic_shape %3671, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+    %int-2_3693 = torch.constant.int -2
+    %3672 = torch.aten.unsqueeze %3667, %int-2_3693 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+    torch.bind_symbolic_shape %3672, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+    %int4_3694 = torch.constant.int 4
+    %int8_3695 = torch.constant.int 8
+    %int4_3696 = torch.constant.int 4
+    %int128_3697 = torch.constant.int 128
+    %3673 = torch.prim.ListConstruct %int4_3694, %457, %int8_3695, %int4_3696, %int128_3697 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %false_3698 = torch.constant.bool false
+    %3674 = torch.aten.expand %3672, %3673, %false_3698 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %3674, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
     %int0_3699 = torch.constant.int 0
-    %3132 = torch.aten.unsqueeze %3131, %int0_3699 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
+    %3675 = torch.aten.clone %3674, %int0_3699 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %3675, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
     %int4_3700 = torch.constant.int 4
-    %int1_3701 = torch.constant.int 1
-    %3133 = torch.prim.ListConstruct %int4_3700, %int1_3701 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int1_3702 = torch.constant.int 1
-    %int1_3703 = torch.constant.int 1
-    %3134 = torch.prim.ListConstruct %int1_3702, %int1_3703 : (!torch.int, !torch.int) -> !torch.list<int>
+    %int32_3701 = torch.constant.int 32
+    %int128_3702 = torch.constant.int 128
+    %3676 = torch.prim.ListConstruct %int4_3700, %457, %int32_3701, %int128_3702 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3677 = torch.aten._unsafe_view %3675, %3676 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+    torch.bind_symbolic_shape %3677, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+    %int-2_3703 = torch.constant.int -2
+    %3678 = torch.aten.unsqueeze %3671, %int-2_3703 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+    torch.bind_symbolic_shape %3678, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
     %int4_3704 = torch.constant.int 4
-    %int0_3705 = torch.constant.int 0
-    %cpu_3706 = torch.constant.device "cpu"
-    %false_3707 = torch.constant.bool false
-    %3135 = torch.aten.empty_strided %3133, %3134, %int4_3704, %int0_3705, %cpu_3706, %false_3707 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64>
-    %int13_3708 = torch.constant.int 13
-    %3136 = torch.aten.fill.Scalar %3135, %int13_3708 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int4_3709 = torch.constant.int 4
-    %int1_3710 = torch.constant.int 1
-    %3137 = torch.prim.ListConstruct %int4_3709, %int1_3710 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3138 = torch.aten.repeat %3132, %3137 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64>
+    %int8_3705 = torch.constant.int 8
+    %int4_3706 = torch.constant.int 4
+    %int128_3707 = torch.constant.int 128
+    %3679 = torch.prim.ListConstruct %int4_3704, %457, %int8_3705, %int4_3706, %int128_3707 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %false_3708 = torch.constant.bool false
+    %3680 = torch.aten.expand %3678, %3679, %false_3708 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %3680, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int0_3709 = torch.constant.int 0
+    %3681 = torch.aten.clone %3680, %int0_3709 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %3681, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int4_3710 = torch.constant.int 4
     %int32_3711 = torch.constant.int 32
-    %3139 = torch.aten.mul.Scalar %3128, %int32_3711 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_3712 = torch.constant.int 1
-    %3140 = torch.aten.add.Tensor %3139, %3136, %int1_3712 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int2_3713 = torch.constant.int 2
-    %3141 = torch.aten.mul.Scalar %3140, %int2_3713 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_3714 = torch.constant.int 1
-    %3142 = torch.aten.add.Tensor %3141, %3138, %int1_3714 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int32_3715 = torch.constant.int 32
-    %3143 = torch.aten.mul.Scalar %3142, %int32_3715 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_3716 = torch.constant.int 1
-    %3144 = torch.aten.add.Tensor %3143, %3130, %int1_3716 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %3145 = torch.prim.ListConstruct %3144 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>>
-    %false_3717 = torch.constant.bool false
-    %3146 = torch.aten.index_put %3125, %3145, %3077, %false_3717 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16>
-    torch.bind_symbolic_shape %3146, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
-    %int32_3718 = torch.constant.int 32
-    %int2_3719 = torch.constant.int 2
-    %int32_3720 = torch.constant.int 32
-    %int8_3721 = torch.constant.int 8
-    %int128_3722 = torch.constant.int 128
-    %3147 = torch.prim.ListConstruct %437, %int32_3718, %int2_3719, %int32_3720, %int8_3721, %int128_3722 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3148 = torch.aten.view %3146, %3147 : !torch.vtensor<[?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3148, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
-    %int2097152_3723 = torch.constant.int 2097152
-    %3149 = torch.prim.ListConstruct %437, %int2097152_3723 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3150 = torch.aten.view %3148, %3149 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
-    torch.bind_symbolic_shape %3150, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+    %int128_3712 = torch.constant.int 128
+    %3682 = torch.prim.ListConstruct %int4_3710, %457, %int32_3711, %int128_3712 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3683 = torch.aten._unsafe_view %3681, %3682 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+    torch.bind_symbolic_shape %3683, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+    %int1_3713 = torch.constant.int 1
+    %int2_3714 = torch.constant.int 2
+    %3684 = torch.aten.transpose.int %3566, %int1_3713, %int2_3714 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+    %int1_3715 = torch.constant.int 1
+    %int2_3716 = torch.constant.int 2
+    %3685 = torch.aten.transpose.int %3677, %int1_3715, %int2_3716 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+    torch.bind_symbolic_shape %3685, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+    %int1_3717 = torch.constant.int 1
+    %int2_3718 = torch.constant.int 2
+    %3686 = torch.aten.transpose.int %3683, %int1_3717, %int2_3718 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+    torch.bind_symbolic_shape %3686, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+    %float0.000000e00_3719 = torch.constant.float 0.000000e+00
+    %false_3720 = torch.constant.bool false
+    %none_3721 = torch.constant.none
+    %3687:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3684, %3685, %3686, %float0.000000e00_3719, %false_3720, %470, %none_3721) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
+    %int1_3722 = torch.constant.int 1
+    %int2_3723 = torch.constant.int 2
+    %3688 = torch.aten.transpose.int %3687#0, %int1_3722, %int2_3723 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
     %int4_3724 = torch.constant.int 4
-    %3151 = torch.prim.ListConstruct %int4_3724, %358 : (!torch.int, !torch.int) -> !torch.list<int>
     %int1_3725 = torch.constant.int 1
-    %3152 = torch.prim.ListConstruct %358, %int1_3725 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int4_3726 = torch.constant.int 4
-    %int0_3727 = torch.constant.int 0
-    %cpu_3728 = torch.constant.device "cpu"
-    %false_3729 = torch.constant.bool false
-    %3153 = torch.aten.empty_strided %3151, %3152, %int4_3726, %int0_3727, %cpu_3728, %false_3729 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64>
-    torch.bind_symbolic_shape %3153, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
-    %int13_3730 = torch.constant.int 13
-    %3154 = torch.aten.fill.Scalar %3153, %int13_3730 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
-    torch.bind_symbolic_shape %3154, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
-    %int32_3731 = torch.constant.int 32
-    %3155 = torch.aten.mul.Scalar %arg3, %int32_3731 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
-    torch.bind_symbolic_shape %3155, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
-    %int1_3732 = torch.constant.int 1
-    %3156 = torch.aten.add.Tensor %3155, %3154, %int1_3732 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
-    torch.bind_symbolic_shape %3156, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
-    %int4_3733 = torch.constant.int 4
-    %3157 = torch.aten.mul.int %int4_3733, %358 : !torch.int, !torch.int -> !torch.int
-    %3158 = torch.prim.ListConstruct %3157 : (!torch.int) -> !torch.list<int>
-    %3159 = torch.aten.view %3156, %3158 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
-    torch.bind_symbolic_shape %3159, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
-    %int32_3734 = torch.constant.int 32
-    %int2_3735 = torch.constant.int 2
-    %int32_3736 = torch.constant.int 32
-    %int8_3737 = torch.constant.int 8
-    %int128_3738 = torch.constant.int 128
-    %3160 = torch.prim.ListConstruct %437, %int32_3734, %int2_3735, %int32_3736, %int8_3737, %int128_3738 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3161 = torch.aten.view %3150, %3160 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3161, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
-    %int32_3739 = torch.constant.int 32
-    %3162 = torch.aten.mul.int %437, %int32_3739 : !torch.int, !torch.int -> !torch.int
-    %int2_3740 = torch.constant.int 2
-    %int32_3741 = torch.constant.int 32
-    %int8_3742 = torch.constant.int 8
-    %int128_3743 = torch.constant.int 128
-    %3163 = torch.prim.ListConstruct %3162, %int2_3740, %int32_3741, %int8_3742, %int128_3743 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3164 = torch.aten.view %3161, %3163 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3164, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16>
-    %int0_3744 = torch.constant.int 0
-    %3165 = torch.aten.index_select %3164, %int0_3744, %3159 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3165, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16>
-    %int4_3745 = torch.constant.int 4
-    %int2_3746 = torch.constant.int 2
-    %int32_3747 = torch.constant.int 32
-    %int8_3748 = torch.constant.int 8
-    %int128_3749 = torch.constant.int 128
-    %3166 = torch.prim.ListConstruct %int4_3745, %358, %int2_3746, %int32_3747, %int8_3748, %int128_3749 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3167 = torch.aten.view %3165, %3166 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3167, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
-    %int0_3750 = torch.constant.int 0
-    %int0_3751 = torch.constant.int 0
-    %int9223372036854775807_3752 = torch.constant.int 9223372036854775807
-    %int1_3753 = torch.constant.int 1
-    %3168 = torch.aten.slice.Tensor %3167, %int0_3750, %int0_3751, %int9223372036854775807_3752, %int1_3753 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3168, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
-    %int1_3754 = torch.constant.int 1
-    %int0_3755 = torch.constant.int 0
-    %int9223372036854775807_3756 = torch.constant.int 9223372036854775807
-    %int1_3757 = torch.constant.int 1
-    %3169 = torch.aten.slice.Tensor %3168, %int1_3754, %int0_3755, %int9223372036854775807_3756, %int1_3757 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3169, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
-    %int2_3758 = torch.constant.int 2
-    %int0_3759 = torch.constant.int 0
-    %3170 = torch.aten.select.int %3169, %int2_3758, %int0_3759 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
-    torch.bind_symbolic_shape %3170, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
-    %int32_3760 = torch.constant.int 32
-    %3171 = torch.aten.mul.int %358, %int32_3760 : !torch.int, !torch.int -> !torch.int
-    %int2_3761 = torch.constant.int 2
-    %int0_3762 = torch.constant.int 0
-    %int1_3763 = torch.constant.int 1
-    %3172 = torch.aten.slice.Tensor %3170, %int2_3761, %int0_3762, %3171, %int1_3763 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
-    torch.bind_symbolic_shape %3172, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
-    %int0_3764 = torch.constant.int 0
-    %3173 = torch.aten.clone %3172, %int0_3764 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
-    torch.bind_symbolic_shape %3173, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
-    %int1_3765 = torch.constant.int 1
-    %3174 = torch.aten.size.int %3169, %int1_3765 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int
-    %int32_3766 = torch.constant.int 32
-    %3175 = torch.aten.mul.int %3174, %int32_3766 : !torch.int, !torch.int -> !torch.int
-    %int4_3767 = torch.constant.int 4
-    %int8_3768 = torch.constant.int 8
-    %int128_3769 = torch.constant.int 128
-    %3176 = torch.prim.ListConstruct %int4_3767, %3175, %int8_3768, %int128_3769 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3177 = torch.aten._unsafe_view %3173, %3176 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
-    torch.bind_symbolic_shape %3177, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
-    %int0_3770 = torch.constant.int 0
-    %int0_3771 = torch.constant.int 0
-    %int9223372036854775807_3772 = torch.constant.int 9223372036854775807
-    %int1_3773 = torch.constant.int 1
-    %3178 = torch.aten.slice.Tensor %3177, %int0_3770, %int0_3771, %int9223372036854775807_3772, %int1_3773 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
-    torch.bind_symbolic_shape %3178, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
-    %int0_3774 = torch.constant.int 0
-    %int0_3775 = torch.constant.int 0
-    %int9223372036854775807_3776 = torch.constant.int 9223372036854775807
-    %int1_3777 = torch.constant.int 1
-    %3179 = torch.aten.slice.Tensor %3167, %int0_3774, %int0_3775, %int9223372036854775807_3776, %int1_3777 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3179, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
-    %int1_3778 = torch.constant.int 1
-    %int0_3779 = torch.constant.int 0
-    %int9223372036854775807_3780 = torch.constant.int 9223372036854775807
-    %int1_3781 = torch.constant.int 1
-    %3180 = torch.aten.slice.Tensor %3179, %int1_3778, %int0_3779, %int9223372036854775807_3780, %int1_3781 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3180, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
-    %int2_3782 = torch.constant.int 2
-    %int1_3783 = torch.constant.int 1
-    %3181 = torch.aten.select.int %3180, %int2_3782, %int1_3783 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
-    torch.bind_symbolic_shape %3181, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
-    %int2_3784 = torch.constant.int 2
-    %int0_3785 = torch.constant.int 0
-    %int1_3786 = torch.constant.int 1
-    %3182 = torch.aten.slice.Tensor %3181, %int2_3784, %int0_3785, %3171, %int1_3786 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
-    torch.bind_symbolic_shape %3182, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
-    %int0_3787 = torch.constant.int 0
-    %3183 = torch.aten.clone %3182, %int0_3787 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
-    torch.bind_symbolic_shape %3183, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
-    %int1_3788 = torch.constant.int 1
-    %3184 = torch.aten.size.int %3180, %int1_3788 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int
-    %int32_3789 = torch.constant.int 32
-    %3185 = torch.aten.mul.int %3184, %int32_3789 : !torch.int, !torch.int -> !torch.int
+    %int4096_3726 = torch.constant.int 4096
+    %3689 = torch.prim.ListConstruct %int4_3724, %int1_3725, %int4096_3726 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3690 = torch.aten.view %3688, %3689 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+    %int-2_3727 = torch.constant.int -2
+    %int-1_3728 = torch.constant.int -1
+    %3691 = torch.aten.transpose.int %207, %int-2_3727, %int-1_3728 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+    %int5_3729 = torch.constant.int 5
+    %3692 = torch.prims.convert_element_type %3691, %int5_3729 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+    %int4_3730 = torch.constant.int 4
+    %int4096_3731 = torch.constant.int 4096
+    %3693 = torch.prim.ListConstruct %int4_3730, %int4096_3731 : (!torch.int, !torch.int) -> !torch.list<int>
+    %3694 = torch.aten.view %3690, %3693 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %3695 = torch.aten.mm %3694, %3692 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
+    %int4_3732 = torch.constant.int 4
+    %int1_3733 = torch.constant.int 1
+    %int4096_3734 = torch.constant.int 4096
+    %3696 = torch.prim.ListConstruct %int4_3732, %int1_3733, %int4096_3734 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3697 = torch.aten.view %3695, %3696 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+    %int1_3735 = torch.constant.int 1
+    %3698 = torch.aten.add.Tensor %3519, %3697, %int1_3735 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %int6_3736 = torch.constant.int 6
+    %3699 = torch.prims.convert_element_type %3698, %int6_3736 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %int2_3737 = torch.constant.int 2
+    %3700 = torch.aten.pow.Tensor_Scalar %3699, %int2_3737 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %int-1_3738 = torch.constant.int -1
+    %3701 = torch.prim.ListConstruct %int-1_3738 : (!torch.int) -> !torch.list<int>
+    %true_3739 = torch.constant.bool true
+    %none_3740 = torch.constant.none
+    %3702 = torch.aten.mean.dim %3700, %3701, %true_3739, %none_3740 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
+    %float9.999990e-06_3741 = torch.constant.float 9.9999997473787516E-6
+    %int1_3742 = torch.constant.int 1
+    %3703 = torch.aten.add.Scalar %3702, %float9.999990e-06_3741, %int1_3742 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
+    %3704 = torch.aten.rsqrt %3703 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
+    %3705 = torch.aten.mul.Tensor %3699, %3704 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
+    %int5_3743 = torch.constant.int 5
+    %3706 = torch.prims.convert_element_type %3705, %int5_3743 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %3707 = torch.aten.mul.Tensor %208, %3706 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
+    %int5_3744 = torch.constant.int 5
+    %3708 = torch.prims.convert_element_type %3707, %int5_3744 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %int-2_3745 = torch.constant.int -2
+    %int-1_3746 = torch.constant.int -1
+    %3709 = torch.aten.transpose.int %209, %int-2_3745, %int-1_3746 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+    %int5_3747 = torch.constant.int 5
+    %3710 = torch.prims.convert_element_type %3709, %int5_3747 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
+    %int4_3748 = torch.constant.int 4
+    %int4096_3749 = torch.constant.int 4096
+    %3711 = torch.prim.ListConstruct %int4_3748, %int4096_3749 : (!torch.int, !torch.int) -> !torch.list<int>
+    %3712 = torch.aten.view %3708, %3711 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %3713 = torch.aten.mm %3712, %3710 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
+    %int4_3750 = torch.constant.int 4
+    %int1_3751 = torch.constant.int 1
+    %int14336_3752 = torch.constant.int 14336
+    %3714 = torch.prim.ListConstruct %int4_3750, %int1_3751, %int14336_3752 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3715 = torch.aten.view %3713, %3714 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
+    %3716 = torch.aten.silu %3715 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
+    %int-2_3753 = torch.constant.int -2
+    %int-1_3754 = torch.constant.int -1
+    %3717 = torch.aten.transpose.int %210, %int-2_3753, %int-1_3754 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+    %int5_3755 = torch.constant.int 5
+    %3718 = torch.prims.convert_element_type %3717, %int5_3755 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
+    %int4_3756 = torch.constant.int 4
+    %int4096_3757 = torch.constant.int 4096
+    %3719 = torch.prim.ListConstruct %int4_3756, %int4096_3757 : (!torch.int, !torch.int) -> !torch.list<int>
+    %3720 = torch.aten.view %3708, %3719 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %3721 = torch.aten.mm %3720, %3718 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
+    %int4_3758 = torch.constant.int 4
+    %int1_3759 = torch.constant.int 1
+    %int14336_3760 = torch.constant.int 14336
+    %3722 = torch.prim.ListConstruct %int4_3758, %int1_3759, %int14336_3760 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3723 = torch.aten.view %3721, %3722 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
+    %3724 = torch.aten.mul.Tensor %3716, %3723 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
+    %int-2_3761 = torch.constant.int -2
+    %int-1_3762 = torch.constant.int -1
+    %3725 = torch.aten.transpose.int %211, %int-2_3761, %int-1_3762 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
+    %int5_3763 = torch.constant.int 5
+    %3726 = torch.prims.convert_element_type %3725, %int5_3763 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16>
+    %int4_3764 = torch.constant.int 4
+    %int14336_3765 = torch.constant.int 14336
+    %3727 = torch.prim.ListConstruct %int4_3764, %int14336_3765 : (!torch.int, !torch.int) -> !torch.list<int>
+    %3728 = torch.aten.view %3724, %3727 : !torch.vtensor<[4,1,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,14336],f16>
+    %3729 = torch.aten.mm %3728, %3726 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16>
+    %int4_3766 = torch.constant.int 4
+    %int1_3767 = torch.constant.int 1
+    %int4096_3768 = torch.constant.int 4096
+    %3730 = torch.prim.ListConstruct %int4_3766, %int1_3767, %int4096_3768 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3731 = torch.aten.view %3729, %3730 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+    %int1_3769 = torch.constant.int 1
+    %3732 = torch.aten.add.Tensor %3698, %3731, %int1_3769 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %int6_3770 = torch.constant.int 6
+    %3733 = torch.prims.convert_element_type %3732, %int6_3770 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %int2_3771 = torch.constant.int 2
+    %3734 = torch.aten.pow.Tensor_Scalar %3733, %int2_3771 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %int-1_3772 = torch.constant.int -1
+    %3735 = torch.prim.ListConstruct %int-1_3772 : (!torch.int) -> !torch.list<int>
+    %true_3773 = torch.constant.bool true
+    %none_3774 = torch.constant.none
+    %3736 = torch.aten.mean.dim %3734, %3735, %true_3773, %none_3774 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
+    %float9.999990e-06_3775 = torch.constant.float 9.9999997473787516E-6
+    %int1_3776 = torch.constant.int 1
+    %3737 = torch.aten.add.Scalar %3736, %float9.999990e-06_3775, %int1_3776 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
+    %3738 = torch.aten.rsqrt %3737 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
+    %3739 = torch.aten.mul.Tensor %3733, %3738 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
+    %int5_3777 = torch.constant.int 5
+    %3740 = torch.prims.convert_element_type %3739, %int5_3777 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %3741 = torch.aten.mul.Tensor %212, %3740 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
+    %int5_3778 = torch.constant.int 5
+    %3742 = torch.prims.convert_element_type %3741, %int5_3778 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %int-2_3779 = torch.constant.int -2
+    %int-1_3780 = torch.constant.int -1
+    %3743 = torch.aten.transpose.int %213, %int-2_3779, %int-1_3780 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+    %int5_3781 = torch.constant.int 5
+    %3744 = torch.prims.convert_element_type %3743, %int5_3781 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+    %int4_3782 = torch.constant.int 4
+    %int4096_3783 = torch.constant.int 4096
+    %3745 = torch.prim.ListConstruct %int4_3782, %int4096_3783 : (!torch.int, !torch.int) -> !torch.list<int>
+    %3746 = torch.aten.view %3742, %3745 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %3747 = torch.aten.mm %3746, %3744 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
+    %int4_3784 = torch.constant.int 4
+    %int1_3785 = torch.constant.int 1
+    %int4096_3786 = torch.constant.int 4096
+    %3748 = torch.prim.ListConstruct %int4_3784, %int1_3785, %int4096_3786 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3749 = torch.aten.view %3747, %3748 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+    %int-2_3787 = torch.constant.int -2
+    %int-1_3788 = torch.constant.int -1
+    %3750 = torch.aten.transpose.int %214, %int-2_3787, %int-1_3788 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+    %int5_3789 = torch.constant.int 5
+    %3751 = torch.prims.convert_element_type %3750, %int5_3789 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
     %int4_3790 = torch.constant.int 4
-    %int8_3791 = torch.constant.int 8
-    %int128_3792 = torch.constant.int 128
-    %3186 = torch.prim.ListConstruct %int4_3790, %3185, %int8_3791, %int128_3792 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3187 = torch.aten._unsafe_view %3183, %3186 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
-    torch.bind_symbolic_shape %3187, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
-    %int0_3793 = torch.constant.int 0
-    %int0_3794 = torch.constant.int 0
-    %int9223372036854775807_3795 = torch.constant.int 9223372036854775807
-    %int1_3796 = torch.constant.int 1
-    %3188 = torch.aten.slice.Tensor %3187, %int0_3793, %int0_3794, %int9223372036854775807_3795, %int1_3796 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
-    torch.bind_symbolic_shape %3188, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
-    %int-2_3797 = torch.constant.int -2
-    %3189 = torch.aten.unsqueeze %3178, %int-2_3797 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
-    torch.bind_symbolic_shape %3189, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
-    %int1_3798 = torch.constant.int 1
-    %3190 = torch.aten.size.int %3177, %int1_3798 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int
-    %int4_3799 = torch.constant.int 4
-    %int8_3800 = torch.constant.int 8
-    %int4_3801 = torch.constant.int 4
-    %int128_3802 = torch.constant.int 128
-    %3191 = torch.prim.ListConstruct %int4_3799, %3190, %int8_3800, %int4_3801, %int128_3802 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %false_3803 = torch.constant.bool false
-    %3192 = torch.aten.expand %3189, %3191, %false_3803 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %3192, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int0_3804 = torch.constant.int 0
-    %3193 = torch.aten.clone %3192, %int0_3804 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %3193, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_3805 = torch.constant.int 4 - %int32_3806 = torch.constant.int 32 - %int128_3807 = torch.constant.int 128 - %3194 = torch.prim.ListConstruct %int4_3805, %3190, %int32_3806, %int128_3807 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3195 = torch.aten._unsafe_view %3193, %3194 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3195, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_3808 = torch.constant.int -2 - %3196 = torch.aten.unsqueeze %3188, %int-2_3808 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3196, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_3809 = torch.constant.int 1 - %3197 = torch.aten.size.int %3187, %int1_3809 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_3810 = torch.constant.int 4 - %int8_3811 = torch.constant.int 8 - %int4_3812 = torch.constant.int 4 - %int128_3813 = torch.constant.int 128 - %3198 = torch.prim.ListConstruct %int4_3810, %3197, %int8_3811, %int4_3812, %int128_3813 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_3814 = torch.constant.bool false - %3199 = torch.aten.expand %3196, %3198, %false_3814 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3199, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_3815 = torch.constant.int 0 - %3200 = torch.aten.clone %3199, %int0_3815 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3200, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_3816 = torch.constant.int 4 - %int32_3817 = torch.constant.int 32 - %int128_3818 = torch.constant.int 128 - %3201 = torch.prim.ListConstruct %int4_3816, %3197, %int32_3817, %int128_3818 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3202 = torch.aten._unsafe_view %3200, %3201 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3202, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_3819 = torch.constant.int 1 - %int2_3820 = torch.constant.int 2 - %3203 = torch.aten.transpose.int %3083, %int1_3819, %int2_3820 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_3821 = torch.constant.int 1 - %int2_3822 = torch.constant.int 2 - %3204 = torch.aten.transpose.int %3195, %int1_3821, %int2_3822 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3204, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_3823 = torch.constant.int 1 - %int2_3824 = torch.constant.int 2 - %3205 = torch.aten.transpose.int %3202, %int1_3823, %int2_3824 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3205, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_3825 = torch.constant.float 0.000000e+00 - %false_3826 = 
torch.constant.bool false - %none_3827 = torch.constant.none - %3206:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3203, %3204, %3205, %float0.000000e00_3825, %false_3826, %368, %none_3827) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_3828 = torch.constant.int 1 - %int2_3829 = torch.constant.int 2 - %3207 = torch.aten.transpose.int %3206#0, %int1_3828, %int2_3829 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_3830 = torch.constant.int 4 - %int1_3831 = torch.constant.int 1 - %int4096_3832 = torch.constant.int 4096 - %3208 = torch.prim.ListConstruct %int4_3830, %int1_3831, %int4096_3832 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3209 = torch.aten.view %3207, %3208 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_3833 = torch.constant.int -2 - %int-1_3834 = torch.constant.int -1 - %3210 = torch.aten.transpose.int %150, %int-2_3833, %int-1_3834 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3835 = torch.constant.int 4 - %int4096_3836 = torch.constant.int 4096 - %3211 = torch.prim.ListConstruct %int4_3835, %int4096_3836 : (!torch.int, !torch.int) -> !torch.list - %3212 = torch.aten.view %3209, %3211 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3213 = torch.aten.mm %3212, %3210 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_3837 = torch.constant.int 4 + %int4096_3791 = torch.constant.int 4096 + %3752 = torch.prim.ListConstruct %int4_3790, %int4096_3791 : (!torch.int, !torch.int) -> !torch.list + %3753 = torch.aten.view %3742, %3752 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3754 = torch.aten.mm %3753, %3751 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_3792 = torch.constant.int 4 + %int1_3793 = torch.constant.int 1 + %int1024_3794 = torch.constant.int 1024 + %3755 = torch.prim.ListConstruct %int4_3792, %int1_3793, %int1024_3794 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3756 = torch.aten.view %3754, %3755 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_3795 = torch.constant.int -2 + %int-1_3796 = torch.constant.int -1 + %3757 = torch.aten.transpose.int %215, %int-2_3795, %int-1_3796 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_3797 = torch.constant.int 5 + %3758 = torch.prims.convert_element_type %3757, %int5_3797 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_3798 = torch.constant.int 4 + %int4096_3799 = torch.constant.int 4096 + %3759 = torch.prim.ListConstruct %int4_3798, %int4096_3799 : (!torch.int, !torch.int) -> !torch.list + %3760 = torch.aten.view %3742, %3759 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3761 = torch.aten.mm %3760, %3758 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_3800 = torch.constant.int 4 + %int1_3801 = torch.constant.int 1 + %int1024_3802 = torch.constant.int 1024 + %3762 = torch.prim.ListConstruct %int4_3800, %int1_3801, %int1024_3802 : 
(!torch.int, !torch.int, !torch.int) -> !torch.list + %3763 = torch.aten.view %3761, %3762 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_3803 = torch.constant.int 4 + %int1_3804 = torch.constant.int 1 + %int32_3805 = torch.constant.int 32 + %int128_3806 = torch.constant.int 128 + %3764 = torch.prim.ListConstruct %int4_3803, %int1_3804, %int32_3805, %int128_3806 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3765 = torch.aten.view %3749, %3764 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_3807 = torch.constant.int 4 + %int1_3808 = torch.constant.int 1 + %int8_3809 = torch.constant.int 8 + %int128_3810 = torch.constant.int 128 + %3766 = torch.prim.ListConstruct %int4_3807, %int1_3808, %int8_3809, %int128_3810 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3767 = torch.aten.view %3756, %3766 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_3811 = torch.constant.int 4 + %int1_3812 = torch.constant.int 1 + %int8_3813 = torch.constant.int 8 + %int128_3814 = torch.constant.int 128 + %3768 = torch.prim.ListConstruct %int4_3811, %int1_3812, %int8_3813, %int128_3814 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3769 = torch.aten.view %3763, %3768 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_3815 = torch.constant.int 1 + %int2_3816 = torch.constant.int 2 + %3770 = torch.aten.transpose.int %3765, %int1_3815, %int2_3816 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %3771 = torch.aten.mul.Tensor %3770, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_3817 = torch.constant.int 3 + %int0_3818 = torch.constant.int 0 + %int64_3819 = torch.constant.int 64 + %int1_3820 = torch.constant.int 1 + %3772 = torch.aten.slice.Tensor %3770, %int3_3817, %int0_3818, %int64_3819, %int1_3820 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_3821 = torch.constant.int 3 + %int64_3822 = torch.constant.int 64 + %int9223372036854775807_3823 = torch.constant.int 9223372036854775807 + %int1_3824 = torch.constant.int 1 + %3773 = torch.aten.slice.Tensor %3770, %int3_3821, %int64_3822, %int9223372036854775807_3823, %int1_3824 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %3774 = torch.aten.neg %3773 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %3775 = torch.prim.ListConstruct %3774, %3772 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_3825 = torch.constant.int -1 + %3776 = torch.aten.cat %3775, %int-1_3825 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %3777 = torch.aten.mul.Tensor %3776, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_3826 = torch.constant.int 1 + %3778 = torch.aten.add.Tensor %3771, %3777, %int1_3826 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_3827 = torch.constant.int 1 + %int2_3828 = torch.constant.int 2 + %3779 = torch.aten.transpose.int %3778, %int1_3827, %int2_3828 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_3829 = 
torch.constant.int 1 + %int2_3830 = torch.constant.int 2 + %3780 = torch.aten.transpose.int %3767, %int1_3829, %int2_3830 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %3781 = torch.aten.mul.Tensor %3780, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_3831 = torch.constant.int 3 + %int0_3832 = torch.constant.int 0 + %int64_3833 = torch.constant.int 64 + %int1_3834 = torch.constant.int 1 + %3782 = torch.aten.slice.Tensor %3780, %int3_3831, %int0_3832, %int64_3833, %int1_3834 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_3835 = torch.constant.int 3 + %int64_3836 = torch.constant.int 64 + %int9223372036854775807_3837 = torch.constant.int 9223372036854775807 %int1_3838 = torch.constant.int 1 - %int4096_3839 = torch.constant.int 4096 - %3214 = torch.prim.ListConstruct %int4_3837, %int1_3838, %int4096_3839 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3215 = torch.aten.view %3213, %3214 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %3783 = torch.aten.slice.Tensor %3780, %int3_3835, %int64_3836, %int9223372036854775807_3837, %int1_3838 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %3784 = torch.aten.neg %3783 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %3785 = torch.prim.ListConstruct %3784, %3782 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_3839 = torch.constant.int -1 + %3786 = torch.aten.cat %3785, %int-1_3839 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %3787 = torch.aten.mul.Tensor %3786, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> %int1_3840 = torch.constant.int 1 - %3216 = torch.aten.add.Tensor %3043, %3215, %int1_3840 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_3841 = torch.constant.int 6 - %3217 = torch.prims.convert_element_type %3216, %int6_3841 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %3788 = torch.aten.add.Tensor %3781, %3787, %int1_3840 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_3841 = torch.constant.int 1 %int2_3842 = torch.constant.int 2 - %3218 = torch.aten.pow.Tensor_Scalar %3217, %int2_3842 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_3843 = torch.constant.int -1 - %3219 = torch.prim.ListConstruct %int-1_3843 : (!torch.int) -> !torch.list - %true_3844 = torch.constant.bool true - %none_3845 = torch.constant.none - %3220 = torch.aten.mean.dim %3218, %3219, %true_3844, %none_3845 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_3846 = torch.constant.float 9.9999997473787516E-6 - %int1_3847 = torch.constant.int 1 - %3221 = torch.aten.add.Scalar %3220, %float9.999990e-06_3846, %int1_3847 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %3222 = torch.aten.rsqrt %3221 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %3223 = torch.aten.mul.Tensor %3217, %3222 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_3848 = 
torch.constant.int 5 - %3224 = torch.prims.convert_element_type %3223, %int5_3848 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %3225 = torch.aten.mul.Tensor %151, %3224 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_3849 = torch.constant.int 5 - %3226 = torch.prims.convert_element_type %3225, %int5_3849 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_3850 = torch.constant.int -2 - %int-1_3851 = torch.constant.int -1 - %3227 = torch.aten.transpose.int %152, %int-2_3850, %int-1_3851 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3852 = torch.constant.int 4 - %int4096_3853 = torch.constant.int 4096 - %3228 = torch.prim.ListConstruct %int4_3852, %int4096_3853 : (!torch.int, !torch.int) -> !torch.list - %3229 = torch.aten.view %3226, %3228 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3230 = torch.aten.mm %3229, %3227 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_3854 = torch.constant.int 4 - %int1_3855 = torch.constant.int 1 - %int14336_3856 = torch.constant.int 14336 - %3231 = torch.prim.ListConstruct %int4_3854, %int1_3855, %int14336_3856 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3232 = torch.aten.view %3230, %3231 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %3233 = torch.aten.silu %3232 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_3857 = torch.constant.int -2 - %int-1_3858 = torch.constant.int -1 - %3234 = torch.aten.transpose.int %153, %int-2_3857, %int-1_3858 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_3859 = torch.constant.int 4 - %int4096_3860 = torch.constant.int 4096 - %3235 = torch.prim.ListConstruct %int4_3859, %int4096_3860 : (!torch.int, !torch.int) -> !torch.list - %3236 = torch.aten.view %3226, %3235 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3237 = torch.aten.mm %3236, %3234 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_3861 = torch.constant.int 4 - %int1_3862 = torch.constant.int 1 - %int14336_3863 = torch.constant.int 14336 - %3238 = torch.prim.ListConstruct %int4_3861, %int1_3862, %int14336_3863 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3239 = torch.aten.view %3237, %3238 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %3240 = torch.aten.mul.Tensor %3233, %3239 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_3864 = torch.constant.int -2 - %int-1_3865 = torch.constant.int -1 - %3241 = torch.aten.transpose.int %154, %int-2_3864, %int-1_3865 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_3866 = torch.constant.int 4 - %int14336_3867 = torch.constant.int 14336 - %3242 = torch.prim.ListConstruct %int4_3866, %int14336_3867 : (!torch.int, !torch.int) -> !torch.list - %3243 = torch.aten.view %3240, %3242 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %3244 = torch.aten.mm %3243, %3241 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_3868 = torch.constant.int 4 + %3789 = torch.aten.transpose.int %3788, %int1_3841, 
%int2_3842 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_3843 = torch.constant.int 32 + %3790 = torch.aten.floor_divide.Scalar %arg2, %int32_3843 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_3844 = torch.constant.int 1 + %3791 = torch.aten.unsqueeze %3790, %int1_3844 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_3845 = torch.constant.int 1 + %false_3846 = torch.constant.bool false + %3792 = torch.aten.gather %arg3, %int1_3845, %3791, %false_3846 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_3847 = torch.constant.int 4 + %int1_3848 = torch.constant.int 1 + %int1_3849 = torch.constant.int 1 + %3793 = torch.prim.ListConstruct %int4_3847, %int1_3848, %int1_3849 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3794 = torch.aten.view %3792, %3793 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_3850 = torch.constant.int 32 + %3795 = torch.aten.remainder.Scalar %arg2, %int32_3850 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_3851 = torch.constant.int 4 + %int1_3852 = torch.constant.int 1 + %int1_3853 = torch.constant.int 1 + %3796 = torch.prim.ListConstruct %int4_3851, %int1_3852, %int1_3853 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3797 = torch.aten.view %3795, %3796 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_3854 = torch.constant.int 8 + %none_3855 = torch.constant.none + %none_3856 = torch.constant.none + %cpu_3857 = torch.constant.device "cpu" + %false_3858 = torch.constant.bool false + %3798 = torch.aten.arange %int8_3854, %none_3855, %none_3856, %cpu_3857, %false_3858 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_3859 = torch.constant.int 1 + %int1_3860 = torch.constant.int 1 + %int8_3861 = torch.constant.int 8 + %3799 = torch.prim.ListConstruct %int1_3859, %int1_3860, %int8_3861 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3800 = torch.aten.view %3798, %3799 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_3862 = torch.constant.none + %3801 = torch.aten.clone %216, %none_3862 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3802 = torch.aten.detach %3801 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3803 = torch.aten.detach %3802 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3804 = torch.aten.detach %3803 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_3863 = torch.constant.int 1 + %int1_3864 = torch.constant.int 1 + %int1_3865 = torch.constant.int 1 + %3805 = torch.prim.ListConstruct %int1_3863, %int1_3864, %int1_3865 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3806 = torch.aten.view %3804, %3805 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_3866 = torch.constant.int 32 + %3807 = torch.aten.mul.Scalar %3794, %int32_3866 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int15 = torch.constant.int 15 + %int1_3867 = torch.constant.int 1 + %3808 = torch.aten.add.Scalar %3807, %int15, %int1_3867 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_3868 = torch.constant.int 2 + %3809 = torch.aten.mul.Scalar %3808, %int2_3868 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_3869 = 
torch.constant.int 1 - %int4096_3870 = torch.constant.int 4096 - %3245 = torch.prim.ListConstruct %int4_3868, %int1_3869, %int4096_3870 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3246 = torch.aten.view %3244, %3245 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %3810 = torch.aten.add.Tensor %3809, %3806, %int1_3869 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_3870 = torch.constant.int 8 + %3811 = torch.aten.mul.Scalar %3810, %int8_3870 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_3871 = torch.constant.int 1 - %3247 = torch.aten.add.Tensor %3216, %3246, %int1_3871 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_3872 = torch.constant.int 6 - %3248 = torch.prims.convert_element_type %3247, %int6_3872 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_3873 = torch.constant.int 2 - %3249 = torch.aten.pow.Tensor_Scalar %3248, %int2_3873 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_3874 = torch.constant.int -1 - %3250 = torch.prim.ListConstruct %int-1_3874 : (!torch.int) -> !torch.list - %true_3875 = torch.constant.bool true - %none_3876 = torch.constant.none - %3251 = torch.aten.mean.dim %3249, %3250, %true_3875, %none_3876 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_3877 = torch.constant.float 9.9999997473787516E-6 - %int1_3878 = torch.constant.int 1 - %3252 = torch.aten.add.Scalar %3251, %float9.999990e-06_3877, %int1_3878 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %3253 = torch.aten.rsqrt %3252 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %3254 = torch.aten.mul.Tensor %3248, %3253 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_3879 = torch.constant.int 5 - %3255 = torch.prims.convert_element_type %3254, %int5_3879 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %3256 = torch.aten.mul.Tensor %155, %3255 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_3880 = torch.constant.int 5 - %3257 = torch.prims.convert_element_type %3256, %int5_3880 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_3881 = torch.constant.int -2 - %int-1_3882 = torch.constant.int -1 - %3258 = torch.aten.transpose.int %156, %int-2_3881, %int-1_3882 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_3883 = torch.constant.int 4 - %int4096_3884 = torch.constant.int 4096 - %3259 = torch.prim.ListConstruct %int4_3883, %int4096_3884 : (!torch.int, !torch.int) -> !torch.list - %3260 = torch.aten.view %3257, %3259 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3261 = torch.aten.mm %3260, %3258 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_3885 = torch.constant.int 4 - %int1_3886 = torch.constant.int 1 - %int4096_3887 = torch.constant.int 4096 - %3262 = torch.prim.ListConstruct %int4_3885, %int1_3886, %int4096_3887 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3263 = torch.aten.view %3261, %3262 : !torch.vtensor<[4,4096],f16>, !torch.list -> 
!torch.vtensor<[4,1,4096],f16> - %int-2_3888 = torch.constant.int -2 - %int-1_3889 = torch.constant.int -1 - %3264 = torch.aten.transpose.int %157, %int-2_3888, %int-1_3889 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3890 = torch.constant.int 4 - %int4096_3891 = torch.constant.int 4096 - %3265 = torch.prim.ListConstruct %int4_3890, %int4096_3891 : (!torch.int, !torch.int) -> !torch.list - %3266 = torch.aten.view %3257, %3265 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3267 = torch.aten.mm %3266, %3264 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_3892 = torch.constant.int 4 - %int1_3893 = torch.constant.int 1 - %int1024_3894 = torch.constant.int 1024 - %3268 = torch.prim.ListConstruct %int4_3892, %int1_3893, %int1024_3894 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3269 = torch.aten.view %3267, %3268 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_3895 = torch.constant.int -2 - %int-1_3896 = torch.constant.int -1 - %3270 = torch.aten.transpose.int %158, %int-2_3895, %int-1_3896 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_3897 = torch.constant.int 4 - %int4096_3898 = torch.constant.int 4096 - %3271 = torch.prim.ListConstruct %int4_3897, %int4096_3898 : (!torch.int, !torch.int) -> !torch.list - %3272 = torch.aten.view %3257, %3271 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3273 = torch.aten.mm %3272, %3270 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_3899 = torch.constant.int 4 + %3812 = torch.aten.add.Tensor %3811, %3800, %int1_3871 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_3872 = torch.constant.int 32 + %3813 = torch.aten.mul.Scalar %3812, %int32_3872 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_3873 = torch.constant.int 1 + %3814 = torch.aten.add.Tensor %3813, %3797, %int1_3873 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_3874 = torch.constant.int 5 + %3815 = torch.prims.convert_element_type %3789, %int5_3874 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_3875 = torch.constant.int 32 + %int2_3876 = torch.constant.int 2 + %int8_3877 = torch.constant.int 8 + %int32_3878 = torch.constant.int 32 + %int128_3879 = torch.constant.int 128 + %3816 = torch.prim.ListConstruct %456, %int32_3875, %int2_3876, %int8_3877, %int32_3878, %int128_3879 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3817 = torch.aten.view %3637, %3816 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3817, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_3880 = torch.constant.int 128 + %3818 = torch.prim.ListConstruct %596, %int128_3880 : (!torch.int, !torch.int) -> !torch.list + %3819 = torch.aten.view %3817, %3818 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3819, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %3820 = torch.prim.ListConstruct %3814 : (!torch.vtensor<[4,1,8],si64>) -> 
!torch.list<optional<vtensor>> + %false_3881 = torch.constant.bool false + %3821 = torch.aten.index_put %3819, %3820, %3815, %false_3881 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3821, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_3882 = torch.constant.int 32 + %int2_3883 = torch.constant.int 2 + %int8_3884 = torch.constant.int 8 + %int32_3885 = torch.constant.int 32 + %int128_3886 = torch.constant.int 128 + %3822 = torch.prim.ListConstruct %456, %int32_3882, %int2_3883, %int8_3884, %int32_3885, %int128_3886 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %3823 = torch.aten.view %3821, %3822 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3823, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_3887 = torch.constant.int 2097152 + %3824 = torch.prim.ListConstruct %456, %int2097152_3887 : (!torch.int, !torch.int) -> !torch.list<int> + %3825 = torch.aten.view %3823, %3824 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %3825, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_3888 = torch.constant.int 32 + %int2_3889 = torch.constant.int 2 + %int8_3890 = torch.constant.int 8 + %int32_3891 = torch.constant.int 32 + %int128_3892 = torch.constant.int 128 + %3826 = torch.prim.ListConstruct %456, %int32_3888, %int2_3889, %int8_3890, %int32_3891, %int128_3892 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %3827 = torch.aten.view %3825, %3826 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3827, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_3893 = torch.constant.int 128 + %3828 = torch.prim.ListConstruct %596, %int128_3893 : (!torch.int, !torch.int) -> !torch.list<int> + %3829 = torch.aten.view %3827, %3828 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3829, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_3894 = torch.constant.none + %3830 = torch.aten.clone %217, %none_3894 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3831 = torch.aten.detach %3830 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3832 = torch.aten.detach %3831 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3833 = torch.aten.detach %3832 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_3895 = torch.constant.int 1 + %int1_3896 = torch.constant.int 1 + %int1_3897 = torch.constant.int 1 + %3834 = torch.prim.ListConstruct %int1_3895, %int1_3896, %int1_3897 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> + %3835 = torch.aten.view %3833, %3834 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64> + %int32_3898 = torch.constant.int 32 + %3836 = torch.aten.mul.Scalar %3794, %int32_3898 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int15_3899 = torch.constant.int 15 %int1_3900 = torch.constant.int 1 - %int1024_3901 = torch.constant.int 1024 - %3274 = torch.prim.ListConstruct %int4_3899, %int1_3900, %int1024_3901 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %3275 = torch.aten.view %3273, %3274 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16> - %int4_3902 = torch.constant.int 4 - %int1_3903 = torch.constant.int 1 - %int32_3904 = torch.constant.int 32 - %int128_3905 = torch.constant.int 128 - %3276 = torch.prim.ListConstruct %int4_3902, %int1_3903, %int32_3904, %int128_3905 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %3277 = torch.aten.view %3263, %3276 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,32,128],f16> - %int4_3906 = torch.constant.int 4 - %int1_3907 = torch.constant.int 1 - %int8_3908 = torch.constant.int 8 - %int128_3909 = torch.constant.int 128 - %3278 = torch.prim.ListConstruct %int4_3906, %int1_3907, %int8_3908, %int128_3909 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %3279 = torch.aten.view %3269, %3278 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16> - %int4_3910 = torch.constant.int 4 - %int1_3911 = torch.constant.int 1 - %int8_3912 = torch.constant.int 8 + %3837 = torch.aten.add.Scalar %3836, %int15_3899, %int1_3900 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_3901 = torch.constant.int 2 + %3838 = torch.aten.mul.Scalar %3837, %int2_3901 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_3902 = torch.constant.int 1 + %3839 = torch.aten.add.Tensor %3838, %3835, %int1_3902 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_3903 = torch.constant.int 8 + %3840 = torch.aten.mul.Scalar %3839, %int8_3903 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_3904 = torch.constant.int 1 + %3841 = torch.aten.add.Tensor %3840, %3800, %int1_3904 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_3905 = torch.constant.int 32 + %3842 = torch.aten.mul.Scalar %3841, %int32_3905 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_3906 = torch.constant.int 1 + %3843 = torch.aten.add.Tensor %3842, %3797, %int1_3906 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_3907 = torch.constant.int 5 + %3844 = torch.prims.convert_element_type %3769, %int5_3907 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %3845 = torch.prim.ListConstruct %3843 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>> + %false_3908 = torch.constant.bool false + %3846 = torch.aten.index_put %3829, %3845, %3844, %false_3908 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %3846, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_3909 = torch.constant.int 32 + %int2_3910 = torch.constant.int 2 + %int8_3911 = torch.constant.int 8 + %int32_3912 = torch.constant.int 32 %int128_3913 = torch.constant.int 128 - %3280 = torch.prim.ListConstruct %int4_3910, %int1_3911, %int8_3912, %int128_3913 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %3281 = torch.aten.view %3275, %3280 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16> - %int6_3914 = torch.constant.int 6 - %3282 = torch.prims.convert_element_type %3277, %int6_3914 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %3283 = 
torch_c.to_builtin_tensor %3282 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %3284 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %3285 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%3283, %3284) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %3286 = torch_c.from_builtin_tensor %3285 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_3915 = torch.constant.int 5 - %3287 = torch.prims.convert_element_type %3286, %int5_3915 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_3916 = torch.constant.int 6 - %3288 = torch.prims.convert_element_type %3279, %int6_3916 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %3289 = torch_c.to_builtin_tensor %3288 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %3290 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %3291 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%3289, %3290) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %3292 = torch_c.from_builtin_tensor %3291 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_3917 = torch.constant.int 5 - %3293 = torch.prims.convert_element_type %3292, %int5_3917 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %3847 = torch.prim.ListConstruct %456, %int32_3909, %int2_3910, %int8_3911, %int32_3912, %int128_3913 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3848 = torch.aten.view %3846, %3847 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3848, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_3914 = torch.constant.int 2097152 + %3849 = torch.prim.ListConstruct %456, %int2097152_3914 : (!torch.int, !torch.int) -> !torch.list + %3850 = torch.aten.view %3848, %3849 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %3850, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_3915 = torch.constant.none + %3851 = torch.aten.clone %218, %none_3915 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3852 = torch.aten.detach %3851 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3853 = torch.aten.detach %3852 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3854 = torch.aten.detach %3853 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_3916 = torch.constant.none + %3855 = torch.aten.clone %219, %none_3916 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3856 = torch.aten.detach %3855 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3857 = torch.aten.detach %3856 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3858 = torch.aten.detach %3857 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_3917 = torch.constant.none + %3859 = torch.aten.clone %220, %none_3917 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %3860 = torch.aten.detach %3859 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3861 = torch.aten.detach %3860 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %3862 = torch.aten.detach %3861 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %int32_3918 = torch.constant.int 32 - %3294 = 
torch.aten.floor_divide.Scalar %arg2, %int32_3918 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_3919 = torch.constant.int 1 - %3295 = torch.aten.unsqueeze %3294, %int1_3919 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3920 = torch.constant.int 1 - %false_3921 = torch.constant.bool false - %3296 = torch.aten.gather %arg3, %int1_3920, %3295, %false_3921 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_3922 = torch.constant.int 32 - %3297 = torch.aten.remainder.Scalar %arg2, %int32_3922 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_3923 = torch.constant.int 1 - %3298 = torch.aten.unsqueeze %3297, %int1_3923 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_3924 = torch.constant.none - %3299 = torch.aten.clone %159, %none_3924 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_3925 = torch.constant.int 0 - %3300 = torch.aten.unsqueeze %3299, %int0_3925 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_3926 = torch.constant.int 4 - %int1_3927 = torch.constant.int 1 - %3301 = torch.prim.ListConstruct %int4_3926, %int1_3927 : (!torch.int, !torch.int) -> !torch.list<int> - %int1_3928 = torch.constant.int 1 - %int1_3929 = torch.constant.int 1 - %3302 = torch.prim.ListConstruct %int1_3928, %int1_3929 : (!torch.int, !torch.int) -> !torch.list<int> + %int2_3919 = torch.constant.int 2 + %int8_3920 = torch.constant.int 8 + %int32_3921 = torch.constant.int 32 + %int128_3922 = torch.constant.int 128 + %3863 = torch.prim.ListConstruct %456, %int32_3918, %int2_3919, %int8_3920, %int32_3921, %int128_3922 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %3864 = torch.aten.view %3850, %3863 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %3864, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %3865 = torch_c.to_builtin_tensor %3864 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16> + %3866 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_3923 = tensor.cast %3866 : tensor<4x?xi64> to tensor<?x?xi64> + %3867 = torch_c.to_builtin_tensor %3854 : !torch.vtensor<[],si64> -> tensor<i64> + %3868 = torch_c.to_builtin_tensor %3858 : !torch.vtensor<[],si64> -> tensor<i64> + %3869 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%3865, %cast_3923, %3867, %3868) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16> + %cast_3924 = tensor.cast %3869 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16> + %3870 = torch_c.from_builtin_tensor %cast_3924 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %3870, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %3871 = torch_c.to_builtin_tensor %3864 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16> + %3872 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_3925 = tensor.cast %3872 : tensor<4x?xi64> to tensor<?x?xi64> + %3873 = torch_c.to_builtin_tensor %3854 : !torch.vtensor<[],si64> -> tensor<i64> + %3874 = torch_c.to_builtin_tensor %3862 : !torch.vtensor<[],si64> -> tensor<i64> + %3875 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%3871, %cast_3925, %3873, %3874) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+ %cast_3926 = tensor.cast %3875 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16> + %3876 = torch_c.from_builtin_tensor %cast_3926 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %3876, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_3927 = torch.constant.int 2 + %int3_3928 = torch.constant.int 3 + %3877 = torch.aten.transpose.int %3870, %int2_3927, %int3_3928 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3877, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_3929 = torch.constant.int 0 + %3878 = torch.aten.clone %3877, %int0_3929 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3878, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int4_3930 = torch.constant.int 4 - %int0_3931 = torch.constant.int 0 - %cpu_3932 = torch.constant.device "cpu" - %false_3933 = torch.constant.bool false - %3303 = torch.aten.empty_strided %3301, %3302, %int4_3930, %int0_3931, %cpu_3932, %false_3933 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int14 = torch.constant.int 14 - %3304 = torch.aten.fill.Scalar %3303, %int14 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_3934 = torch.constant.int 4 - %int1_3935 = torch.constant.int 1 - %3305 = torch.prim.ListConstruct %int4_3934, %int1_3935 : (!torch.int, !torch.int) -> !torch.list<int> - %3306 = torch.aten.repeat %3300, %3305 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64> - %int32_3936 = torch.constant.int 32 - %3307 = torch.aten.mul.Scalar %3296, %int32_3936 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3937 = torch.constant.int 1 - %3308 = torch.aten.add.Tensor %3307, %3304, %int1_3937 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_3938 = torch.constant.int 2 - %3309 = torch.aten.mul.Scalar %3308, %int2_3938 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3939 = torch.constant.int 1 - %3310 = torch.aten.add.Tensor %3309, %3306, %int1_3939 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_3940 = torch.constant.int 32 - %3311 = torch.aten.mul.Scalar %3310, %int32_3940 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3941 = torch.constant.int 1 - %3312 = torch.aten.add.Tensor %3311, %3298, %int1_3941 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_3942 = torch.constant.int 32 - %int2_3943 = torch.constant.int 2 - %int32_3944 = torch.constant.int 32 - %int8_3945 = torch.constant.int 8 - %int128_3946 = torch.constant.int 128 - %3313 = torch.prim.ListConstruct %437, %int32_3942, %int2_3943, %int32_3944, %int8_3945, %int128_3946 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %3314 = torch.aten.view %3150, %3313 : !torch.vtensor<[?,2097152],f16>, 
!torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3314, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> + %int8_3931 = torch.constant.int 8 + %int128_3932 = torch.constant.int 128 + %3879 = torch.prim.ListConstruct %int4_3930, %457, %int8_3931, %int128_3932 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3880 = torch.aten._unsafe_view %3878, %3879 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3880, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_3933 = torch.constant.int 2 + %int3_3934 = torch.constant.int 3 + %3881 = torch.aten.transpose.int %3876, %int2_3933, %int3_3934 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3881, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_3935 = torch.constant.int 0 + %3882 = torch.aten.clone %3881, %int0_3935 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %3882, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_3936 = torch.constant.int 4 + %int8_3937 = torch.constant.int 8 + %int128_3938 = torch.constant.int 128 + %3883 = torch.prim.ListConstruct %int4_3936, %457, %int8_3937, %int128_3938 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3884 = torch.aten._unsafe_view %3882, %3883 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %3884, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_3939 = torch.constant.int -2 + %3885 = torch.aten.unsqueeze %3880, %int-2_3939 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %3885, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_3940 = torch.constant.int 4 + %int8_3941 = torch.constant.int 8 + %int4_3942 = torch.constant.int 4 + %int128_3943 = torch.constant.int 128 + %3886 = torch.prim.ListConstruct %int4_3940, %457, %int8_3941, %int4_3942, %int128_3943 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_3944 = torch.constant.bool false + %3887 = torch.aten.expand %3885, %3886, %false_3944 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3887, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_3945 = torch.constant.int 0 + %3888 = torch.aten.clone %3887, %int0_3945 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3888, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_3946 = torch.constant.int 4 %int32_3947 = torch.constant.int 32 - %3315 = torch.aten.mul.int %437, %int32_3947 : !torch.int, !torch.int -> !torch.int - %int2_3948 = torch.constant.int 2 - %3316 = torch.aten.mul.int %3315, %int2_3948 : !torch.int, !torch.int -> !torch.int - %int32_3949 = torch.constant.int 32 - %3317 = torch.aten.mul.int %3316, %int32_3949 : !torch.int, !torch.int -> !torch.int - %int8_3950 = torch.constant.int 8 - %int128_3951 = torch.constant.int 128 - %3318 = torch.prim.ListConstruct 
%3317, %int8_3950, %int128_3951 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3319 = torch.aten.view %3314, %3318 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %3319, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %3320 = torch.prim.ListConstruct %3312 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_3952 = torch.constant.bool false - %3321 = torch.aten.index_put %3319, %3320, %3293, %false_3952 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %3321, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_3953 = torch.constant.int 32 - %int2_3954 = torch.constant.int 2 - %int32_3955 = torch.constant.int 32 - %int8_3956 = torch.constant.int 8 - %int128_3957 = torch.constant.int 128 - %3322 = torch.prim.ListConstruct %437, %int32_3953, %int2_3954, %int32_3955, %int8_3956, %int128_3957 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3323 = torch.aten.view %3321, %3322 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3323, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3958 = torch.constant.int 2097152 - %3324 = torch.prim.ListConstruct %437, %int2097152_3958 : (!torch.int, !torch.int) -> !torch.list - %3325 = torch.aten.view %3323, %3324 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3325, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_3959 = torch.constant.int 32 + %int128_3948 = torch.constant.int 128 + %3889 = torch.prim.ListConstruct %int4_3946, %457, %int32_3947, %int128_3948 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3890 = torch.aten._unsafe_view %3888, %3889 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3890, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_3949 = torch.constant.int -2 + %3891 = torch.aten.unsqueeze %3884, %int-2_3949 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %3891, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_3950 = torch.constant.int 4 + %int8_3951 = torch.constant.int 8 + %int4_3952 = torch.constant.int 4 + %int128_3953 = torch.constant.int 128 + %3892 = torch.prim.ListConstruct %int4_3950, %457, %int8_3951, %int4_3952, %int128_3953 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_3954 = torch.constant.bool false + %3893 = torch.aten.expand %3891, %3892, %false_3954 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3893, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_3955 = torch.constant.int 0 + %3894 = torch.aten.clone %3893, %int0_3955 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %3894, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_3956 = torch.constant.int 4 + %int32_3957 = 
torch.constant.int 32 + %int128_3958 = torch.constant.int 128 + %3895 = torch.prim.ListConstruct %int4_3956, %457, %int32_3957, %int128_3958 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3896 = torch.aten._unsafe_view %3894, %3895 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %3896, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_3959 = torch.constant.int 1 %int2_3960 = torch.constant.int 2 - %int32_3961 = torch.constant.int 32 - %int8_3962 = torch.constant.int 8 - %int128_3963 = torch.constant.int 128 - %3326 = torch.prim.ListConstruct %437, %int32_3959, %int2_3960, %int32_3961, %int8_3962, %int128_3963 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3327 = torch.aten.view %3325, %3326 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3327, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_3964 = torch.constant.int 8 - %int128_3965 = torch.constant.int 128 - %3328 = torch.prim.ListConstruct %3317, %int8_3964, %int128_3965 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3329 = torch.aten.view %3327, %3328 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %3329, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_3966 = torch.constant.int 32 - %3330 = torch.aten.floor_divide.Scalar %arg2, %int32_3966 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_3967 = torch.constant.int 1 - %3331 = torch.aten.unsqueeze %3330, %int1_3967 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %3897 = torch.aten.transpose.int %3779, %int1_3959, %int2_3960 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_3961 = torch.constant.int 1 + %int2_3962 = torch.constant.int 2 + %3898 = torch.aten.transpose.int %3890, %int1_3961, %int2_3962 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3898, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_3963 = torch.constant.int 1 + %int2_3964 = torch.constant.int 2 + %3899 = torch.aten.transpose.int %3896, %int1_3963, %int2_3964 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %3899, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_3965 = torch.constant.float 0.000000e+00 + %false_3966 = torch.constant.bool false + %none_3967 = torch.constant.none + %3900:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3897, %3898, %3899, %float0.000000e00_3965, %false_3966, %470, %none_3967) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) %int1_3968 = torch.constant.int 1 - %false_3969 = torch.constant.bool false - %3332 = torch.aten.gather %arg3, %int1_3968, %3331, %false_3969 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_3970 = torch.constant.int 32 
- %3333 = torch.aten.remainder.Scalar %arg2, %int32_3970 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int2_3969 = torch.constant.int 2 + %3901 = torch.aten.transpose.int %3900#0, %int1_3968, %int2_3969 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_3970 = torch.constant.int 4 %int1_3971 = torch.constant.int 1 - %3334 = torch.aten.unsqueeze %3333, %int1_3971 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_3972 = torch.constant.none - %3335 = torch.aten.clone %160, %none_3972 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_3973 = torch.constant.int 0 - %3336 = torch.aten.unsqueeze %3335, %int0_3973 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_3974 = torch.constant.int 4 - %int1_3975 = torch.constant.int 1 - %3337 = torch.prim.ListConstruct %int4_3974, %int1_3975 : (!torch.int, !torch.int) -> !torch.list - %int1_3976 = torch.constant.int 1 - %int1_3977 = torch.constant.int 1 - %3338 = torch.prim.ListConstruct %int1_3976, %int1_3977 : (!torch.int, !torch.int) -> !torch.list + %int4096_3972 = torch.constant.int 4096 + %3902 = torch.prim.ListConstruct %int4_3970, %int1_3971, %int4096_3972 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3903 = torch.aten.view %3901, %3902 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_3973 = torch.constant.int -2 + %int-1_3974 = torch.constant.int -1 + %3904 = torch.aten.transpose.int %221, %int-2_3973, %int-1_3974 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_3975 = torch.constant.int 5 + %3905 = torch.prims.convert_element_type %3904, %int5_3975 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_3976 = torch.constant.int 4 + %int4096_3977 = torch.constant.int 4096 + %3906 = torch.prim.ListConstruct %int4_3976, %int4096_3977 : (!torch.int, !torch.int) -> !torch.list + %3907 = torch.aten.view %3903, %3906 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3908 = torch.aten.mm %3907, %3905 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_3978 = torch.constant.int 4 - %int0_3979 = torch.constant.int 0 - %cpu_3980 = torch.constant.device "cpu" - %false_3981 = torch.constant.bool false - %3339 = torch.aten.empty_strided %3337, %3338, %int4_3978, %int0_3979, %cpu_3980, %false_3981 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int14_3982 = torch.constant.int 14 - %3340 = torch.aten.fill.Scalar %3339, %int14_3982 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_3983 = torch.constant.int 4 - %int1_3984 = torch.constant.int 1 - %3341 = torch.prim.ListConstruct %int4_3983, %int1_3984 : (!torch.int, !torch.int) -> !torch.list - %3342 = torch.aten.repeat %3336, %3341 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_3985 = torch.constant.int 32 - %3343 = torch.aten.mul.Scalar %3332, %int32_3985 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3986 = torch.constant.int 1 - %3344 = torch.aten.add.Tensor %3343, %3340, %int1_3986 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_3987 = torch.constant.int 2 - %3345 = torch.aten.mul.Scalar %3344, %int2_3987 : !torch.vtensor<[4,1],si64>, 
!torch.int -> !torch.vtensor<[4,1],si64> + %int1_3979 = torch.constant.int 1 + %int4096_3980 = torch.constant.int 4096 + %3909 = torch.prim.ListConstruct %int4_3978, %int1_3979, %int4096_3980 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3910 = torch.aten.view %3908, %3909 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_3981 = torch.constant.int 1 + %3911 = torch.aten.add.Tensor %3732, %3910, %int1_3981 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_3982 = torch.constant.int 6 + %3912 = torch.prims.convert_element_type %3911, %int6_3982 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_3983 = torch.constant.int 2 + %3913 = torch.aten.pow.Tensor_Scalar %3912, %int2_3983 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_3984 = torch.constant.int -1 + %3914 = torch.prim.ListConstruct %int-1_3984 : (!torch.int) -> !torch.list + %true_3985 = torch.constant.bool true + %none_3986 = torch.constant.none + %3915 = torch.aten.mean.dim %3913, %3914, %true_3985, %none_3986 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_3987 = torch.constant.float 9.9999997473787516E-6 %int1_3988 = torch.constant.int 1 - %3346 = torch.aten.add.Tensor %3345, %3342, %int1_3988 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_3989 = torch.constant.int 32 - %3347 = torch.aten.mul.Scalar %3346, %int32_3989 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_3990 = torch.constant.int 1 - %3348 = torch.aten.add.Tensor %3347, %3334, %int1_3990 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %3349 = torch.prim.ListConstruct %3348 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_3991 = torch.constant.bool false - %3350 = torch.aten.index_put %3329, %3349, %3281, %false_3991 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %3350, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_3992 = torch.constant.int 32 - %int2_3993 = torch.constant.int 2 - %int32_3994 = torch.constant.int 32 - %int8_3995 = torch.constant.int 8 - %int128_3996 = torch.constant.int 128 - %3351 = torch.prim.ListConstruct %437, %int32_3992, %int2_3993, %int32_3994, %int8_3995, %int128_3996 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3352 = torch.aten.view %3350, %3351 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3352, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_3997 = torch.constant.int 2097152 - %3353 = torch.prim.ListConstruct %437, %int2097152_3997 : (!torch.int, !torch.int) -> !torch.list - %3354 = torch.aten.view %3352, %3353 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3354, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_3998 = torch.constant.int 4 - %3355 = torch.prim.ListConstruct %int4_3998, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_3999 = torch.constant.int 1 - %3356 = torch.prim.ListConstruct %358, 
%int1_3999 : (!torch.int, !torch.int) -> !torch.list - %int4_4000 = torch.constant.int 4 - %int0_4001 = torch.constant.int 0 - %cpu_4002 = torch.constant.device "cpu" - %false_4003 = torch.constant.bool false - %3357 = torch.aten.empty_strided %3355, %3356, %int4_4000, %int0_4001, %cpu_4002, %false_4003 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3357, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int14_4004 = torch.constant.int 14 - %3358 = torch.aten.fill.Scalar %3357, %int14_4004 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3358, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_4005 = torch.constant.int 32 - %3359 = torch.aten.mul.Scalar %arg3, %int32_4005 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3359, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_4006 = torch.constant.int 1 - %3360 = torch.aten.add.Tensor %3359, %3358, %int1_4006 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3360, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_4007 = torch.constant.int 4 - %3361 = torch.aten.mul.int %int4_4007, %358 : !torch.int, !torch.int -> !torch.int - %3362 = torch.prim.ListConstruct %3361 : (!torch.int) -> !torch.list - %3363 = torch.aten.view %3360, %3362 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3363, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_4008 = torch.constant.int 32 - %int2_4009 = torch.constant.int 2 - %int32_4010 = torch.constant.int 32 - %int8_4011 = torch.constant.int 8 - %int128_4012 = torch.constant.int 128 - %3364 = torch.prim.ListConstruct %437, %int32_4008, %int2_4009, %int32_4010, %int8_4011, %int128_4012 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3365 = torch.aten.view %3354, %3364 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3365, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4013 = torch.constant.int 32 - %3366 = torch.aten.mul.int %437, %int32_4013 : !torch.int, !torch.int -> !torch.int - %int2_4014 = torch.constant.int 2 - %int32_4015 = torch.constant.int 32 - %int8_4016 = torch.constant.int 8 - %int128_4017 = torch.constant.int 128 - %3367 = torch.prim.ListConstruct %3366, %int2_4014, %int32_4015, %int8_4016, %int128_4017 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3368 = torch.aten.view %3365, %3367 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %3368, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_4018 = torch.constant.int 0 - %3369 = torch.aten.index_select %3368, %int0_4018, %3363 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %3369, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_4019 = torch.constant.int 4 - %int2_4020 = torch.constant.int 2 - %int32_4021 = torch.constant.int 32 - %int8_4022 = torch.constant.int 
8 - %int128_4023 = torch.constant.int 128 - %3370 = torch.prim.ListConstruct %int4_4019, %358, %int2_4020, %int32_4021, %int8_4022, %int128_4023 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3371 = torch.aten.view %3369, %3370 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3371, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_4024 = torch.constant.int 0 - %int0_4025 = torch.constant.int 0 - %int9223372036854775807_4026 = torch.constant.int 9223372036854775807 - %int1_4027 = torch.constant.int 1 - %3372 = torch.aten.slice.Tensor %3371, %int0_4024, %int0_4025, %int9223372036854775807_4026, %int1_4027 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3372, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_4028 = torch.constant.int 1 - %int0_4029 = torch.constant.int 0 - %int9223372036854775807_4030 = torch.constant.int 9223372036854775807 + %3916 = torch.aten.add.Scalar %3915, %float9.999990e-06_3987, %int1_3988 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %3917 = torch.aten.rsqrt %3916 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %3918 = torch.aten.mul.Tensor %3912, %3917 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_3989 = torch.constant.int 5 + %3919 = torch.prims.convert_element_type %3918, %int5_3989 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %3920 = torch.aten.mul.Tensor %222, %3919 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_3990 = torch.constant.int 5 + %3921 = torch.prims.convert_element_type %3920, %int5_3990 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_3991 = torch.constant.int -2 + %int-1_3992 = torch.constant.int -1 + %3922 = torch.aten.transpose.int %223, %int-2_3991, %int-1_3992 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_3993 = torch.constant.int 5 + %3923 = torch.prims.convert_element_type %3922, %int5_3993 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_3994 = torch.constant.int 4 + %int4096_3995 = torch.constant.int 4096 + %3924 = torch.prim.ListConstruct %int4_3994, %int4096_3995 : (!torch.int, !torch.int) -> !torch.list + %3925 = torch.aten.view %3921, %3924 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3926 = torch.aten.mm %3925, %3923 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_3996 = torch.constant.int 4 + %int1_3997 = torch.constant.int 1 + %int14336_3998 = torch.constant.int 14336 + %3927 = torch.prim.ListConstruct %int4_3996, %int1_3997, %int14336_3998 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3928 = torch.aten.view %3926, %3927 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %3929 = torch.aten.silu %3928 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_3999 = torch.constant.int -2 + %int-1_4000 = torch.constant.int -1 + %3930 = torch.aten.transpose.int %224, %int-2_3999, %int-1_4000 : !torch.vtensor<[14336,4096],f16>, 
!torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_4001 = torch.constant.int 5 + %3931 = torch.prims.convert_element_type %3930, %int5_4001 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_4002 = torch.constant.int 4 + %int4096_4003 = torch.constant.int 4096 + %3932 = torch.prim.ListConstruct %int4_4002, %int4096_4003 : (!torch.int, !torch.int) -> !torch.list + %3933 = torch.aten.view %3921, %3932 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3934 = torch.aten.mm %3933, %3931 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_4004 = torch.constant.int 4 + %int1_4005 = torch.constant.int 1 + %int14336_4006 = torch.constant.int 14336 + %3935 = torch.prim.ListConstruct %int4_4004, %int1_4005, %int14336_4006 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3936 = torch.aten.view %3934, %3935 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %3937 = torch.aten.mul.Tensor %3929, %3936 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_4007 = torch.constant.int -2 + %int-1_4008 = torch.constant.int -1 + %3938 = torch.aten.transpose.int %225, %int-2_4007, %int-1_4008 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_4009 = torch.constant.int 5 + %3939 = torch.prims.convert_element_type %3938, %int5_4009 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_4010 = torch.constant.int 4 + %int14336_4011 = torch.constant.int 14336 + %3940 = torch.prim.ListConstruct %int4_4010, %int14336_4011 : (!torch.int, !torch.int) -> !torch.list + %3941 = torch.aten.view %3937, %3940 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %3942 = torch.aten.mm %3941, %3939 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_4012 = torch.constant.int 4 + %int1_4013 = torch.constant.int 1 + %int4096_4014 = torch.constant.int 4096 + %3943 = torch.prim.ListConstruct %int4_4012, %int1_4013, %int4096_4014 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3944 = torch.aten.view %3942, %3943 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_4015 = torch.constant.int 1 + %3945 = torch.aten.add.Tensor %3911, %3944, %int1_4015 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_4016 = torch.constant.int 6 + %3946 = torch.prims.convert_element_type %3945, %int6_4016 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_4017 = torch.constant.int 2 + %3947 = torch.aten.pow.Tensor_Scalar %3946, %int2_4017 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_4018 = torch.constant.int -1 + %3948 = torch.prim.ListConstruct %int-1_4018 : (!torch.int) -> !torch.list + %true_4019 = torch.constant.bool true + %none_4020 = torch.constant.none + %3949 = torch.aten.mean.dim %3947, %3948, %true_4019, %none_4020 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_4021 = torch.constant.float 9.9999997473787516E-6 + %int1_4022 = torch.constant.int 1 + %3950 = torch.aten.add.Scalar %3949, %float9.999990e-06_4021, %int1_4022 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> 
!torch.vtensor<[4,1,1],f32> + %3951 = torch.aten.rsqrt %3950 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %3952 = torch.aten.mul.Tensor %3946, %3951 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_4023 = torch.constant.int 5 + %3953 = torch.prims.convert_element_type %3952, %int5_4023 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %3954 = torch.aten.mul.Tensor %226, %3953 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_4024 = torch.constant.int 5 + %3955 = torch.prims.convert_element_type %3954, %int5_4024 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_4025 = torch.constant.int -2 + %int-1_4026 = torch.constant.int -1 + %3956 = torch.aten.transpose.int %227, %int-2_4025, %int-1_4026 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_4027 = torch.constant.int 5 + %3957 = torch.prims.convert_element_type %3956, %int5_4027 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_4028 = torch.constant.int 4 + %int4096_4029 = torch.constant.int 4096 + %3958 = torch.prim.ListConstruct %int4_4028, %int4096_4029 : (!torch.int, !torch.int) -> !torch.list + %3959 = torch.aten.view %3955, %3958 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3960 = torch.aten.mm %3959, %3957 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_4030 = torch.constant.int 4 %int1_4031 = torch.constant.int 1 - %3373 = torch.aten.slice.Tensor %3372, %int1_4028, %int0_4029, %int9223372036854775807_4030, %int1_4031 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3373, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_4032 = torch.constant.int 2 - %int0_4033 = torch.constant.int 0 - %3374 = torch.aten.select.int %3373, %int2_4032, %int0_4033 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3374, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_4034 = torch.constant.int 32 - %3375 = torch.aten.mul.int %358, %int32_4034 : !torch.int, !torch.int -> !torch.int - %int2_4035 = torch.constant.int 2 - %int0_4036 = torch.constant.int 0 - %int1_4037 = torch.constant.int 1 - %3376 = torch.aten.slice.Tensor %3374, %int2_4035, %int0_4036, %3375, %int1_4037 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3376, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_4038 = torch.constant.int 0 - %3377 = torch.aten.clone %3376, %int0_4038 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3377, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4096_4032 = torch.constant.int 4096 + %3961 = torch.prim.ListConstruct %int4_4030, %int1_4031, %int4096_4032 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3962 = torch.aten.view %3960, %3961 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_4033 = 
torch.constant.int -2 + %int-1_4034 = torch.constant.int -1 + %3963 = torch.aten.transpose.int %228, %int-2_4033, %int-1_4034 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_4035 = torch.constant.int 5 + %3964 = torch.prims.convert_element_type %3963, %int5_4035 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_4036 = torch.constant.int 4 + %int4096_4037 = torch.constant.int 4096 + %3965 = torch.prim.ListConstruct %int4_4036, %int4096_4037 : (!torch.int, !torch.int) -> !torch.list + %3966 = torch.aten.view %3955, %3965 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3967 = torch.aten.mm %3966, %3964 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_4038 = torch.constant.int 4 %int1_4039 = torch.constant.int 1 - %3378 = torch.aten.size.int %3373, %int1_4039 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_4040 = torch.constant.int 32 - %3379 = torch.aten.mul.int %3378, %int32_4040 : !torch.int, !torch.int -> !torch.int - %int4_4041 = torch.constant.int 4 - %int8_4042 = torch.constant.int 8 - %int128_4043 = torch.constant.int 128 - %3380 = torch.prim.ListConstruct %int4_4041, %3379, %int8_4042, %int128_4043 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3381 = torch.aten._unsafe_view %3377, %3380 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3381, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_4044 = torch.constant.int 0 - %int0_4045 = torch.constant.int 0 - %int9223372036854775807_4046 = torch.constant.int 9223372036854775807 + %int1024_4040 = torch.constant.int 1024 + %3968 = torch.prim.ListConstruct %int4_4038, %int1_4039, %int1024_4040 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3969 = torch.aten.view %3967, %3968 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_4041 = torch.constant.int -2 + %int-1_4042 = torch.constant.int -1 + %3970 = torch.aten.transpose.int %229, %int-2_4041, %int-1_4042 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_4043 = torch.constant.int 5 + %3971 = torch.prims.convert_element_type %3970, %int5_4043 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_4044 = torch.constant.int 4 + %int4096_4045 = torch.constant.int 4096 + %3972 = torch.prim.ListConstruct %int4_4044, %int4096_4045 : (!torch.int, !torch.int) -> !torch.list + %3973 = torch.aten.view %3955, %3972 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %3974 = torch.aten.mm %3973, %3971 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_4046 = torch.constant.int 4 %int1_4047 = torch.constant.int 1 - %3382 = torch.aten.slice.Tensor %3381, %int0_4044, %int0_4045, %int9223372036854775807_4046, %int1_4047 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3382, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_4048 = torch.constant.int 0 - %int0_4049 = torch.constant.int 0 - %int9223372036854775807_4050 = torch.constant.int 9223372036854775807 - %int1_4051 = torch.constant.int 1 - %3383 = torch.aten.slice.Tensor %3371, 
%int0_4048, %int0_4049, %int9223372036854775807_4050, %int1_4051 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3383, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
-    %int1_4052 = torch.constant.int 1
-    %int0_4053 = torch.constant.int 0
-    %int9223372036854775807_4054 = torch.constant.int 9223372036854775807
-    %int1_4055 = torch.constant.int 1
-    %3384 = torch.aten.slice.Tensor %3383, %int1_4052, %int0_4053, %int9223372036854775807_4054, %int1_4055 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3384, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
-    %int2_4056 = torch.constant.int 2
-    %int1_4057 = torch.constant.int 1
-    %3385 = torch.aten.select.int %3384, %int2_4056, %int1_4057 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
-    torch.bind_symbolic_shape %3385, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
-    %int2_4058 = torch.constant.int 2
-    %int0_4059 = torch.constant.int 0
-    %int1_4060 = torch.constant.int 1
-    %3386 = torch.aten.slice.Tensor %3385, %int2_4058, %int0_4059, %3375, %int1_4060 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
-    torch.bind_symbolic_shape %3386, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
-    %int0_4061 = torch.constant.int 0
-    %3387 = torch.aten.clone %3386, %int0_4061 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
-    torch.bind_symbolic_shape %3387, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
-    %int1_4062 = torch.constant.int 1
-    %3388 = torch.aten.size.int %3384, %int1_4062 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int
-    %int32_4063 = torch.constant.int 32
-    %3389 = torch.aten.mul.int %3388, %int32_4063 : !torch.int, !torch.int -> !torch.int
-    %int4_4064 = torch.constant.int 4
-    %int8_4065 = torch.constant.int 8
-    %int128_4066 = torch.constant.int 128
-    %3390 = torch.prim.ListConstruct %int4_4064, %3389, %int8_4065, %int128_4066 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3391 = torch.aten._unsafe_view %3387, %3390 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
-    torch.bind_symbolic_shape %3391, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
-    %int0_4067 = torch.constant.int 0
-    %int0_4068 = torch.constant.int 0
+    %int1024_4048 = torch.constant.int 1024
+    %3975 = torch.prim.ListConstruct %int4_4046, %int1_4047, %int1024_4048 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3976 = torch.aten.view %3974, %3975 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
+    %int4_4049 = torch.constant.int 4
+    %int1_4050 = torch.constant.int 1
+    %int32_4051 = torch.constant.int 32
+    %int128_4052 = torch.constant.int 128
+    %3977 = torch.prim.ListConstruct %int4_4049, %int1_4050, %int32_4051, %int128_4052 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3978 = torch.aten.view %3962, %3977 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,32,128],f16>
+    %int4_4053 = torch.constant.int 4
+    %int1_4054 = torch.constant.int 1
+    %int8_4055 = torch.constant.int 8
+    %int128_4056 = torch.constant.int 128
+    %3979 = torch.prim.ListConstruct %int4_4053, %int1_4054, %int8_4055, %int128_4056 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3980 = torch.aten.view %3969, %3979 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
+    %int4_4057 = torch.constant.int 4
+    %int1_4058 = torch.constant.int 1
+    %int8_4059 = torch.constant.int 8
+    %int128_4060 = torch.constant.int 128
+    %3981 = torch.prim.ListConstruct %int4_4057, %int1_4058, %int8_4059, %int128_4060 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %3982 = torch.aten.view %3976, %3981 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
+    %int1_4061 = torch.constant.int 1
+    %int2_4062 = torch.constant.int 2
+    %3983 = torch.aten.transpose.int %3978, %int1_4061, %int2_4062 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+    %3984 = torch.aten.mul.Tensor %3983, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16>
+    %int3_4063 = torch.constant.int 3
+    %int0_4064 = torch.constant.int 0
+    %int64_4065 = torch.constant.int 64
+    %int1_4066 = torch.constant.int 1
+    %3985 = torch.aten.slice.Tensor %3983, %int3_4063, %int0_4064, %int64_4065, %int1_4066 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16>
+    %int3_4067 = torch.constant.int 3
+    %int64_4068 = torch.constant.int 64
     %int9223372036854775807_4069 = torch.constant.int 9223372036854775807
     %int1_4070 = torch.constant.int 1
-    %3392 = torch.aten.slice.Tensor %3391, %int0_4067, %int0_4068, %int9223372036854775807_4069, %int1_4070 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
-    torch.bind_symbolic_shape %3392, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
-    %int-2_4071 = torch.constant.int -2
-    %3393 = torch.aten.unsqueeze %3382, %int-2_4071 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
-    torch.bind_symbolic_shape %3393, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+    %3986 = torch.aten.slice.Tensor %3983, %int3_4067, %int64_4068, %int9223372036854775807_4069, %int1_4070 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16>
+    %3987 = torch.aten.neg %3986 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16>
+    %3988 = torch.prim.ListConstruct %3987, %3985 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list<vtensor>
+    %int-1_4071 = torch.constant.int -1
+    %3989 = torch.aten.cat %3988, %int-1_4071 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+    %3990 = torch.aten.mul.Tensor %3989, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16>
     %int1_4072 = torch.constant.int 1
-    %3394 = torch.aten.size.int %3381, %int1_4072 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int
-    %int4_4073 = torch.constant.int 4
-    %int8_4074 = torch.constant.int 8
-    %int4_4075 = torch.constant.int 4
-    %int128_4076 = torch.constant.int 128
-    %3395 = torch.prim.ListConstruct %int4_4073, %3394, %int8_4074, %int4_4075, %int128_4076 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %false_4077 = torch.constant.bool false
-    %3396 = torch.aten.expand %3393, %3395, %false_4077 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %3396, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %3991 = torch.aten.add.Tensor %3984, %3990, %int1_4072 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+    %int1_4073 = torch.constant.int 1
+    %int2_4074 = torch.constant.int 2
+    %3992 = torch.aten.transpose.int %3991, %int1_4073, %int2_4074 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
+    %int1_4075 = torch.constant.int 1
+    %int2_4076 = torch.constant.int 2
+    %3993 = torch.aten.transpose.int %3980, %int1_4075, %int2_4076 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16>
+    %3994 = torch.aten.mul.Tensor %3993, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16>
+    %int3_4077 = torch.constant.int 3
     %int0_4078 = torch.constant.int 0
-    %3397 = torch.aten.clone %3396, %int0_4078 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %3397, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int4_4079 = torch.constant.int 4
-    %int32_4080 = torch.constant.int 32
-    %int128_4081 = torch.constant.int 128
-    %3398 = torch.prim.ListConstruct %int4_4079, %3394, %int32_4080, %int128_4081 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3399 = torch.aten._unsafe_view %3397, %3398 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
-    torch.bind_symbolic_shape %3399, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
-    %int-2_4082 = torch.constant.int -2
-    %3400 = torch.aten.unsqueeze %3392, %int-2_4082 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
-    torch.bind_symbolic_shape %3400, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
-    %int1_4083 = torch.constant.int 1
-    %3401 = torch.aten.size.int %3391, %int1_4083 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int
-    %int4_4084 = torch.constant.int 4
-    %int8_4085 = torch.constant.int 8
-    %int4_4086 = torch.constant.int 4
-    %int128_4087 = torch.constant.int 128
-    %3402 = torch.prim.ListConstruct %int4_4084, %3401, %int8_4085, %int4_4086, %int128_4087 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %false_4088 = torch.constant.bool false
-    %3403 = torch.aten.expand %3400, %3402, %false_4088 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %3403, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int0_4089 = torch.constant.int 0
-    %3404 = torch.aten.clone %3403, %int0_4089 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %3404, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int4_4090 = torch.constant.int 4
-    %int32_4091 = torch.constant.int 32
-    %int128_4092 = torch.constant.int 128
-    %3405 = torch.prim.ListConstruct %int4_4090, %3401, %int32_4091, %int128_4092 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3406 = torch.aten._unsafe_view %3404, %3405 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
-    torch.bind_symbolic_shape %3406, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
-    %int1_4093 = torch.constant.int 1
-    %int2_4094 = torch.constant.int 2
-    %3407 = torch.aten.transpose.int %3287, %int1_4093, %int2_4094 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+    %int64_4079 = torch.constant.int 64
+    %int1_4080 = torch.constant.int 1
+    %3995 = torch.aten.slice.Tensor %3993, %int3_4077, %int0_4078, %int64_4079, %int1_4080 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16>
+    %int3_4081 = torch.constant.int 3
+    %int64_4082 = torch.constant.int 64
+    %int9223372036854775807_4083 = torch.constant.int 9223372036854775807
+    %int1_4084 = torch.constant.int 1
+    %3996 = torch.aten.slice.Tensor %3993, %int3_4081, %int64_4082, %int9223372036854775807_4083, %int1_4084 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16>
+    %3997 = torch.aten.neg %3996 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16>
+    %3998 = torch.prim.ListConstruct %3997, %3995 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list<vtensor>
+    %int-1_4085 = torch.constant.int -1
+    %3999 = torch.aten.cat %3998, %int-1_4085 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,8,1,128],f16>
+    %4000 = torch.aten.mul.Tensor %3999, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16>
+    %int1_4086 = torch.constant.int 1
+    %4001 = torch.aten.add.Tensor %3994, %4000, %int1_4086 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16>
+    %int1_4087 = torch.constant.int 1
+    %int2_4088 = torch.constant.int 2
+    %4002 = torch.aten.transpose.int %4001, %int1_4087, %int2_4088 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+    %int32_4089 = torch.constant.int 32
+    %4003 = torch.aten.floor_divide.Scalar %arg2, %int32_4089 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
+    %int1_4090 = torch.constant.int 1
+    %4004 = torch.aten.unsqueeze %4003, %int1_4090 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
+    %int1_4091 = torch.constant.int 1
+    %false_4092 = torch.constant.bool false
+    %4005 = torch.aten.gather %arg3, %int1_4091, %4004, %false_4092 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64>
+    %int4_4093 = torch.constant.int 4
+    %int1_4094 = torch.constant.int 1
     %int1_4095 = torch.constant.int 1
-    %int2_4096 = torch.constant.int 2
-    %3408 = torch.aten.transpose.int %3399, %int1_4095, %int2_4096 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
-    torch.bind_symbolic_shape %3408, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
-    %int1_4097 = torch.constant.int 1
-    %int2_4098 = torch.constant.int 2
-    %3409 = torch.aten.transpose.int %3406, %int1_4097, %int2_4098 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
-    torch.bind_symbolic_shape %3409, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
-    %float0.000000e00_4099 = torch.constant.float 0.000000e+00
-    %false_4100 = torch.constant.bool false
+    %4006 = torch.prim.ListConstruct %int4_4093, %int1_4094, %int1_4095 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4007 = torch.aten.view %4005, %4006 : !torch.vtensor<[4,1],si64>, !torch.list<int> -> !torch.vtensor<[4,1,1],si64>
+    %int32_4096 = torch.constant.int 32
+    %4008 = torch.aten.remainder.Scalar %arg2, %int32_4096 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
+    %int4_4097 = torch.constant.int 4
+    %int1_4098 = torch.constant.int 1
+    %int1_4099 = torch.constant.int 1
+    %4009 = torch.prim.ListConstruct %int4_4097, %int1_4098, %int1_4099 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4010 = torch.aten.view %4008, %4009 : !torch.vtensor<[4],si64>, !torch.list<int> -> !torch.vtensor<[4,1,1],si64>
+    %int8_4100 = torch.constant.int 8
     %none_4101 = torch.constant.none
-    %3410:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3407, %3408, %3409, %float0.000000e00_4099, %false_4100, %368, %none_4101) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
-    %int1_4102 = torch.constant.int 1
-    %int2_4103 = torch.constant.int 2
-    %3411 = torch.aten.transpose.int %3410#0, %int1_4102, %int2_4103 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
-    %int4_4104 = torch.constant.int 4
+    %none_4102 = torch.constant.none
+    %cpu_4103 = torch.constant.device "cpu"
+    %false_4104 = torch.constant.bool false
+    %4011 = torch.aten.arange %int8_4100, %none_4101, %none_4102, %cpu_4103, %false_4104 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64>
     %int1_4105 = torch.constant.int 1
-    %int4096_4106 = torch.constant.int 4096
-    %3412 = torch.prim.ListConstruct %int4_4104, %int1_4105, %int4096_4106 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3413 = torch.aten.view %3411, %3412 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
-    %int-2_4107 = torch.constant.int -2
-    %int-1_4108 = torch.constant.int -1
-    %3414 = torch.aten.transpose.int %161, %int-2_4107, %int-1_4108 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
-    %int4_4109 = torch.constant.int 4
-    %int4096_4110 = torch.constant.int 4096
-    %3415 = torch.prim.ListConstruct %int4_4109, %int4096_4110 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3416 = torch.aten.view %3413, %3415 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %3417 = torch.aten.mm %3416, %3414 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
-    %int4_4111 = torch.constant.int 4
-    %int1_4112 = torch.constant.int 1
-    %int4096_4113 = torch.constant.int 4096
-    %3418 = torch.prim.ListConstruct %int4_4111, %int1_4112, %int4096_4113 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3419 = torch.aten.view %3417, %3418 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
-    %int1_4114 = torch.constant.int 1
-    %3420 = torch.aten.add.Tensor %3247, %3419, %int1_4114 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %int6_4115 = torch.constant.int 6
-    %3421 = torch.prims.convert_element_type %3420, %int6_4115 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
-    %int2_4116 = torch.constant.int 2
-    %3422 = torch.aten.pow.Tensor_Scalar %3421, %int2_4116 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
-    %int-1_4117 = torch.constant.int -1
-    %3423 = torch.prim.ListConstruct %int-1_4117 : (!torch.int) -> !torch.list<int>
-    %true_4118 = torch.constant.bool true
-    %none_4119 = torch.constant.none
-    %3424 = torch.aten.mean.dim %3422, %3423, %true_4118, %none_4119 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
-    %float9.999990e-06_4120 = torch.constant.float 9.9999997473787516E-6
-    %int1_4121 = torch.constant.int 1
-    %3425 = torch.aten.add.Scalar %3424, %float9.999990e-06_4120, %int1_4121 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
-    %3426 = torch.aten.rsqrt %3425 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
-    %3427 = torch.aten.mul.Tensor %3421, %3426 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
-    %int5_4122 = torch.constant.int 5
-    %3428 = torch.prims.convert_element_type %3427, %int5_4122 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %3429 = torch.aten.mul.Tensor %162, %3428 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
-    %int5_4123 = torch.constant.int 5
-    %3430 = torch.prims.convert_element_type %3429, %int5_4123 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %int-2_4124 = torch.constant.int -2
-    %int-1_4125 = torch.constant.int -1
-    %3431 = torch.aten.transpose.int %163, %int-2_4124, %int-1_4125 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
-    %int4_4126 = torch.constant.int 4
-    %int4096_4127 = torch.constant.int 4096
-    %3432 = torch.prim.ListConstruct %int4_4126, %int4096_4127 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3433 = torch.aten.view %3430, %3432 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %3434 = torch.aten.mm %3433, %3431 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
-    %int4_4128 = torch.constant.int 4
-    %int1_4129 = torch.constant.int 1
-    %int14336_4130 = torch.constant.int 14336
-    %3435 = torch.prim.ListConstruct %int4_4128, %int1_4129, %int14336_4130 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3436 = torch.aten.view %3434, %3435 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
-    %3437 = torch.aten.silu %3436 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
-    %int-2_4131 = torch.constant.int -2
-    %int-1_4132 = torch.constant.int -1
-    %3438 = torch.aten.transpose.int %164, %int-2_4131, %int-1_4132 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
-    %int4_4133 = torch.constant.int 4
-    %int4096_4134 = torch.constant.int 4096
-    %3439 = torch.prim.ListConstruct %int4_4133, %int4096_4134 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3440 = torch.aten.view %3430, %3439 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %3441 = torch.aten.mm %3440, %3438 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
-    %int4_4135 = torch.constant.int 4
-    %int1_4136 = torch.constant.int 1
-    %int14336_4137 = torch.constant.int 14336
-    %3442 = torch.prim.ListConstruct %int4_4135, %int1_4136, %int14336_4137 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3443 = torch.aten.view %3441, %3442 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
-    %3444 = torch.aten.mul.Tensor %3437, %3443 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
-    %int-2_4138 = torch.constant.int -2
-    %int-1_4139 = torch.constant.int -1
-    %3445 = torch.aten.transpose.int %165, %int-2_4138, %int-1_4139 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
-    %int4_4140 = torch.constant.int 4
-    %int14336_4141 = torch.constant.int 14336
-    %3446 = torch.prim.ListConstruct %int4_4140, %int14336_4141 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3447 = torch.aten.view %3444, %3446 : !torch.vtensor<[4,1,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,14336],f16>
-    %3448 = torch.aten.mm %3447, %3445 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16>
-    %int4_4142 = torch.constant.int 4
+    %int1_4106 = torch.constant.int 1
+    %int8_4107 = torch.constant.int 8
+    %4012 = torch.prim.ListConstruct %int1_4105, %int1_4106, %int8_4107 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4013 = torch.aten.view %4011, %4012 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[1,1,8],si64>
+    %none_4108 = torch.constant.none
+    %4014 = torch.aten.clone %230, %none_4108 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+    %4015 = torch.aten.detach %4014 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4016 = torch.aten.detach %4015 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4017 = torch.aten.detach %4016 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %int1_4109 = torch.constant.int 1
+    %int1_4110 = torch.constant.int 1
+    %int1_4111 = torch.constant.int 1
+    %4018 = torch.prim.ListConstruct %int1_4109, %int1_4110, %int1_4111 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4019 = torch.aten.view %4017, %4018 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64>
+    %int32_4112 = torch.constant.int 32
+    %4020 = torch.aten.mul.Scalar %4007, %int32_4112 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int16 = torch.constant.int 16
+    %int1_4113 = torch.constant.int 1
+    %4021 = torch.aten.add.Scalar %4020, %int16, %int1_4113 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int2_4114 = torch.constant.int 2
+    %4022 = torch.aten.mul.Scalar %4021, %int2_4114 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int1_4115 = torch.constant.int 1
+    %4023 = torch.aten.add.Tensor %4022, %4019, %int1_4115 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int8_4116 = torch.constant.int 8
+    %4024 = torch.aten.mul.Scalar %4023, %int8_4116 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int1_4117 = torch.constant.int 1
+    %4025 = torch.aten.add.Tensor %4024, %4013, %int1_4117 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+    %int32_4118 = torch.constant.int 32
+    %4026 = torch.aten.mul.Scalar %4025, %int32_4118 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+    %int1_4119 = torch.constant.int 1
+    %4027 = torch.aten.add.Tensor %4026, %4010, %int1_4119 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+    %int5_4120 = torch.constant.int 5
+    %4028 = torch.prims.convert_element_type %4002, %int5_4120 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+    %int32_4121 = torch.constant.int 32
+    %int2_4122 = torch.constant.int 2
+    %int8_4123 = torch.constant.int 8
+    %int32_4124 = torch.constant.int 32
+    %int128_4125 = torch.constant.int 128
+    %4029 = torch.prim.ListConstruct %456, %int32_4121, %int2_4122, %int8_4123, %int32_4124, %int128_4125 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4030 = torch.aten.view %3850, %4029 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4030, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %int128_4126 = torch.constant.int 128
+    %4031 = torch.prim.ListConstruct %596, %int128_4126 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4032 = torch.aten.view %4030, %4031 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+    torch.bind_symbolic_shape %4032, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+    %4033 = torch.prim.ListConstruct %4027 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+    %false_4127 = torch.constant.bool false
+    %4034 = torch.aten.index_put %4032, %4033, %4028, %false_4127 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+    torch.bind_symbolic_shape %4034, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+    %int32_4128 = torch.constant.int 32
+    %int2_4129 = torch.constant.int 2
+    %int8_4130 = torch.constant.int 8
+    %int32_4131 = torch.constant.int 32
+    %int128_4132 = torch.constant.int 128
+    %4035 = torch.prim.ListConstruct %456, %int32_4128, %int2_4129, %int8_4130, %int32_4131, %int128_4132 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4036 = torch.aten.view %4034, %4035 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4036, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %int2097152_4133 = torch.constant.int 2097152
+    %4037 = torch.prim.ListConstruct %456, %int2097152_4133 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4038 = torch.aten.view %4036, %4037 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+    torch.bind_symbolic_shape %4038, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+    %int32_4134 = torch.constant.int 32
+    %int2_4135 = torch.constant.int 2
+    %int8_4136 = torch.constant.int 8
+    %int32_4137 = torch.constant.int 32
+    %int128_4138 = torch.constant.int 128
+    %4039 = torch.prim.ListConstruct %456, %int32_4134, %int2_4135, %int8_4136, %int32_4137, %int128_4138 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4040 = torch.aten.view %4038, %4039 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4040, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %int128_4139 = torch.constant.int 128
+    %4041 = torch.prim.ListConstruct %596, %int128_4139 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4042 = torch.aten.view %4040, %4041 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+    torch.bind_symbolic_shape %4042, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+    %none_4140 = torch.constant.none
+    %4043 = torch.aten.clone %231, %none_4140 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+    %4044 = torch.aten.detach %4043 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4045 = torch.aten.detach %4044 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4046 = torch.aten.detach %4045 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %int1_4141 = torch.constant.int 1
+    %int1_4142 = torch.constant.int 1
     %int1_4143 = torch.constant.int 1
-    %int4096_4144 = torch.constant.int 4096
-    %3449 = torch.prim.ListConstruct %int4_4142, %int1_4143, %int4096_4144 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3450 = torch.aten.view %3448, %3449 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
-    %int1_4145 = torch.constant.int 1
-    %3451 = torch.aten.add.Tensor %3420, %3450, %int1_4145 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %int6_4146 = torch.constant.int 6
-    %3452 = torch.prims.convert_element_type %3451, %int6_4146 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %4047 = torch.prim.ListConstruct %int1_4141, %int1_4142, %int1_4143 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4048 = torch.aten.view %4046, %4047 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64>
+    %int32_4144 = torch.constant.int 32
+    %4049 = torch.aten.mul.Scalar %4007, %int32_4144 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int16_4145 = torch.constant.int 16
+    %int1_4146 = torch.constant.int 1
+    %4050 = torch.aten.add.Scalar %4049, %int16_4145, %int1_4146 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64>
     %int2_4147 = torch.constant.int 2
-    %3453 = torch.aten.pow.Tensor_Scalar %3452, %int2_4147 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
-    %int-1_4148 = torch.constant.int -1
-    %3454 = torch.prim.ListConstruct %int-1_4148 : (!torch.int) -> !torch.list<int>
-    %true_4149 = torch.constant.bool true
-    %none_4150 = torch.constant.none
-    %3455 = torch.aten.mean.dim %3453, %3454, %true_4149, %none_4150 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
-    %float9.999990e-06_4151 = torch.constant.float 9.9999997473787516E-6
+    %4051 = torch.aten.mul.Scalar %4050, %int2_4147 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int1_4148 = torch.constant.int 1
+    %4052 = torch.aten.add.Tensor %4051, %4048, %int1_4148 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int8_4149 = torch.constant.int 8
+    %4053 = torch.aten.mul.Scalar %4052, %int8_4149 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int1_4150 = torch.constant.int 1
+    %4054 = torch.aten.add.Tensor %4053, %4013, %int1_4150 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+    %int32_4151 = torch.constant.int 32
+    %4055 = torch.aten.mul.Scalar %4054, %int32_4151 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
     %int1_4152 = torch.constant.int 1
-    %3456 = torch.aten.add.Scalar %3455, %float9.999990e-06_4151, %int1_4152 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
-    %3457 = torch.aten.rsqrt %3456 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
-    %3458 = torch.aten.mul.Tensor %3452, %3457 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
+    %4056 = torch.aten.add.Tensor %4055, %4010, %int1_4152 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
     %int5_4153 = torch.constant.int 5
-    %3459 = torch.prims.convert_element_type %3458, %int5_4153 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %3460 = torch.aten.mul.Tensor %166, %3459 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
-    %int5_4154 = torch.constant.int 5
-    %3461 = torch.prims.convert_element_type %3460, %int5_4154 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %int-2_4155 = torch.constant.int -2
-    %int-1_4156 = torch.constant.int -1
-    %3462 = torch.aten.transpose.int %167, %int-2_4155, %int-1_4156 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
-    %int4_4157 = torch.constant.int 4
-    %int4096_4158 = torch.constant.int 4096
-    %3463 = torch.prim.ListConstruct %int4_4157, %int4096_4158 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3464 = torch.aten.view %3461, %3463 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %3465 = torch.aten.mm %3464, %3462 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
-    %int4_4159 = torch.constant.int 4
-    %int1_4160 = torch.constant.int 1
-    %int4096_4161 = torch.constant.int 4096
-    %3466 = torch.prim.ListConstruct %int4_4159, %int1_4160, %int4096_4161 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3467 = torch.aten.view %3465, %3466 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
-    %int-2_4162 = torch.constant.int -2
-    %int-1_4163 = torch.constant.int -1
-    %3468 = torch.aten.transpose.int %168, %int-2_4162, %int-1_4163 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
-    %int4_4164 = torch.constant.int 4
-    %int4096_4165 = torch.constant.int 4096
-    %3469 = torch.prim.ListConstruct %int4_4164, %int4096_4165 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3470 = torch.aten.view %3461, %3469 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %3471 = torch.aten.mm %3470, %3468 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
-    %int4_4166 = torch.constant.int 4
-    %int1_4167 = torch.constant.int 1
-    %int1024_4168 = torch.constant.int 1024
-    %3472 = torch.prim.ListConstruct %int4_4166, %int1_4167, %int1024_4168 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3473 = torch.aten.view %3471, %3472 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
-    %int-2_4169 = torch.constant.int -2
-    %int-1_4170 = torch.constant.int -1
-    %3474 = torch.aten.transpose.int %169, %int-2_4169, %int-1_4170 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
-    %int4_4171 = torch.constant.int 4
-    %int4096_4172 = torch.constant.int 4096
-    %3475 = torch.prim.ListConstruct %int4_4171, %int4096_4172 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3476 = torch.aten.view %3461, %3475 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %3477 = torch.aten.mm %3476, %3474 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
-    %int4_4173 = torch.constant.int 4
-    %int1_4174 = torch.constant.int 1
-    %int1024_4175 = torch.constant.int 1024
-    %3478 = torch.prim.ListConstruct %int4_4173, %int1_4174, %int1024_4175 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3479 = torch.aten.view %3477, %3478 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
+    %4057 = torch.prims.convert_element_type %3982, %int5_4153 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+    %4058 = torch.prim.ListConstruct %4056 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+    %false_4154 = torch.constant.bool false
+    %4059 = torch.aten.index_put %4042, %4058, %4057, %false_4154 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+    torch.bind_symbolic_shape %4059, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+    %int32_4155 = torch.constant.int 32
+    %int2_4156 = torch.constant.int 2
+    %int8_4157 = torch.constant.int 8
+    %int32_4158 = torch.constant.int 32
+    %int128_4159 = torch.constant.int 128
+    %4060 = torch.prim.ListConstruct %456, %int32_4155, %int2_4156, %int8_4157, %int32_4158, %int128_4159 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4061 = torch.aten.view %4059, %4060 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4061, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %int2097152_4160 = torch.constant.int 2097152
+    %4062 = torch.prim.ListConstruct %456, %int2097152_4160 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4063 = torch.aten.view %4061, %4062 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+    torch.bind_symbolic_shape %4063, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+    %none_4161 = torch.constant.none
+    %4064 = torch.aten.clone %232, %none_4161 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+    %4065 = torch.aten.detach %4064 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4066 = torch.aten.detach %4065 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4067 = torch.aten.detach %4066 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %none_4162 = torch.constant.none
+    %4068 = torch.aten.clone %233, %none_4162 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+    %4069 = torch.aten.detach %4068 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4070 = torch.aten.detach %4069 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4071 = torch.aten.detach %4070 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %none_4163 = torch.constant.none
+    %4072 = torch.aten.clone %234, %none_4163 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+    %4073 = torch.aten.detach %4072 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4074 = torch.aten.detach %4073 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4075 = torch.aten.detach %4074 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %int32_4164 = torch.constant.int 32
+    %int2_4165 = torch.constant.int 2
+    %int8_4166 = torch.constant.int 8
+    %int32_4167 = torch.constant.int 32
+    %int128_4168 = torch.constant.int 128
+    %4076 = torch.prim.ListConstruct %456, %int32_4164, %int2_4165, %int8_4166, %int32_4167, %int128_4168 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4077 = torch.aten.view %4063, %4076 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4077, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %4078 = torch_c.to_builtin_tensor %4077 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+    %4079 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+    %cast_4169 = tensor.cast %4079 : tensor<4x?xi64> to tensor<?x?xi64>
+    %4080 = torch_c.to_builtin_tensor %4067 : !torch.vtensor<[],si64> -> tensor<i64>
+    %4081 = torch_c.to_builtin_tensor %4071 : !torch.vtensor<[],si64> -> tensor<i64>
+    %4082 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%4078, %cast_4169, %4080, %4081) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+    %cast_4170 = tensor.cast %4082 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+    %4083 = torch_c.from_builtin_tensor %cast_4170 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+    torch.bind_symbolic_shape %4083, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+    %4084 = torch_c.to_builtin_tensor %4077 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+    %4085 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+    %cast_4171 = tensor.cast %4085 : tensor<4x?xi64> to tensor<?x?xi64>
+    %4086 = torch_c.to_builtin_tensor %4067 : !torch.vtensor<[],si64> -> tensor<i64>
+    %4087 = torch_c.to_builtin_tensor %4075 : !torch.vtensor<[],si64> -> tensor<i64>
+    %4088 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%4084, %cast_4171, %4086, %4087) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+    %cast_4172 = tensor.cast %4088 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+    %4089 = torch_c.from_builtin_tensor %cast_4172 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+    torch.bind_symbolic_shape %4089, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+    %int2_4173 = torch.constant.int 2
+    %int3_4174 = torch.constant.int 3
+    %4090 = torch.aten.transpose.int %4083, %int2_4173, %int3_4174 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %4090, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+    %int0_4175 = torch.constant.int 0
+    %4091 = torch.aten.clone %4090, %int0_4175 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %4091, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
     %int4_4176 = torch.constant.int 4
-    %int1_4177 = torch.constant.int 1
-    %int32_4178 = torch.constant.int 32
-    %int128_4179 = torch.constant.int 128
-    %3480 = torch.prim.ListConstruct %int4_4176, %int1_4177, %int32_4178, %int128_4179 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3481 = torch.aten.view %3467, %3480 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,32,128],f16>
-    %int4_4180 = torch.constant.int 4
-    %int1_4181 = torch.constant.int 1
-    %int8_4182 = torch.constant.int 8
-    %int128_4183 = torch.constant.int 128
-    %3482 = torch.prim.ListConstruct %int4_4180, %int1_4181, %int8_4182, %int128_4183 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3483 = torch.aten.view %3473, %3482 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
-    %int4_4184 = torch.constant.int 4
-    %int1_4185 = torch.constant.int 1
-    %int8_4186 = torch.constant.int 8
-    %int128_4187 = torch.constant.int 128
-    %3484 = torch.prim.ListConstruct %int4_4184, %int1_4185, %int8_4186, %int128_4187 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3485 = torch.aten.view %3479, %3484 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
-    %int6_4188 = torch.constant.int 6
-    %3486 = torch.prims.convert_element_type %3481, %int6_4188 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32>
-    %3487 = torch_c.to_builtin_tensor %3486 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32>
-    %3488 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32>
-    %3489 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%3487, %3488) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32>
-    %3490 = torch_c.from_builtin_tensor %3489 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32>
-    %int5_4189 = torch.constant.int 5
-    %3491 = torch.prims.convert_element_type %3490, %int5_4189 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
-    %int6_4190 = torch.constant.int 6
-    %3492 = torch.prims.convert_element_type %3483, %int6_4190 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32>
-    %3493 = torch_c.to_builtin_tensor %3492 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32>
-    %3494 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32>
-    %3495 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%3493, %3494) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32>
-    %3496 = torch_c.from_builtin_tensor %3495 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32>
-    %int5_4191 = torch.constant.int 5
-    %3497 = torch.prims.convert_element_type %3496, %int5_4191 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
-    %int32_4192 = torch.constant.int 32
-    %3498 = torch.aten.floor_divide.Scalar %arg2, %int32_4192 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
-    %int1_4193 = torch.constant.int 1
-    %3499 = torch.aten.unsqueeze %3498, %int1_4193 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4194 = torch.constant.int 1
-    %false_4195 = torch.constant.bool false
-    %3500 = torch.aten.gather %arg3, %int1_4194, %3499, %false_4195 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64>
-    %int32_4196 = torch.constant.int 32
-    %3501 = torch.aten.remainder.Scalar %arg2, %int32_4196 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
-    %int1_4197 = torch.constant.int 1
-    %3502 = torch.aten.unsqueeze %3501, %int1_4197 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %none_4198 = torch.constant.none
-    %3503 = torch.aten.clone %170, %none_4198 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
-    %int0_4199 = torch.constant.int 0
-    %3504 = torch.aten.unsqueeze %3503, %int0_4199 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
-    %int4_4200 = torch.constant.int 4
-    %int1_4201 = torch.constant.int 1
-    %3505 = torch.prim.ListConstruct %int4_4200, %int1_4201 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int1_4202 = torch.constant.int 1
-    %int1_4203 = torch.constant.int 1
-    %3506 = torch.prim.ListConstruct %int1_4202, %int1_4203 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int4_4204 = torch.constant.int 4
-    %int0_4205 = torch.constant.int 0
-    %cpu_4206 = torch.constant.device "cpu"
-    %false_4207 = torch.constant.bool false
-    %3507 = torch.aten.empty_strided %3505, %3506, %int4_4204, %int0_4205, %cpu_4206, %false_4207 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64>
-    %int15 = torch.constant.int 15
-    %3508 = torch.aten.fill.Scalar %3507, %int15 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int4_4208 = torch.constant.int 4
+    %int8_4177 = torch.constant.int 8
+    %int128_4178 = torch.constant.int 128
+    %4092 = torch.prim.ListConstruct %int4_4176, %457, %int8_4177, %int128_4178 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4093 = torch.aten._unsafe_view %4091, %4092 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+    torch.bind_symbolic_shape %4093, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+    %int2_4179 = torch.constant.int 2
+    %int3_4180 = torch.constant.int 3
+    %4094 = torch.aten.transpose.int %4089, %int2_4179, %int3_4180 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %4094, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+    %int0_4181 = torch.constant.int 0
+    %4095 = torch.aten.clone %4094, %int0_4181 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %4095, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+    %int4_4182 = torch.constant.int 4
+    %int8_4183 = torch.constant.int 8
+    %int128_4184 = torch.constant.int 128
+    %4096 = torch.prim.ListConstruct %int4_4182, %457, %int8_4183, %int128_4184 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4097 = torch.aten._unsafe_view %4095, %4096 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+    torch.bind_symbolic_shape %4097, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+    %int-2_4185 = torch.constant.int -2
+    %4098 = torch.aten.unsqueeze %4093, %int-2_4185 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+    torch.bind_symbolic_shape %4098, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+    %int4_4186 = torch.constant.int 4
+    %int8_4187 = torch.constant.int 8
+    %int4_4188 = torch.constant.int 4
+    %int128_4189 = torch.constant.int 128
+    %4099 = torch.prim.ListConstruct %int4_4186, %457, %int8_4187, %int4_4188, %int128_4189 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %false_4190 = torch.constant.bool false
+    %4100 = torch.aten.expand %4098, %4099, %false_4190 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %4100, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int0_4191 = torch.constant.int 0
+    %4101 = torch.aten.clone %4100, %int0_4191 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %4101, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int4_4192 = torch.constant.int 4
+    %int32_4193 = torch.constant.int 32
+    %int128_4194 = torch.constant.int 128
+    %4102 = torch.prim.ListConstruct %int4_4192, %457, %int32_4193, %int128_4194 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4103 = torch.aten._unsafe_view %4101, %4102 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+    torch.bind_symbolic_shape %4103, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+    %int-2_4195 = torch.constant.int -2
+    %4104 = torch.aten.unsqueeze %4097, %int-2_4195 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+    torch.bind_symbolic_shape %4104, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+    %int4_4196 = torch.constant.int 4
+    %int8_4197 = torch.constant.int 8
+    %int4_4198 = torch.constant.int 4
+    %int128_4199 = torch.constant.int 128
+    %4105 = torch.prim.ListConstruct %int4_4196, %457, %int8_4197, %int4_4198, %int128_4199 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %false_4200 = torch.constant.bool false
+    %4106 = torch.aten.expand %4104, %4105, %false_4200 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %4106, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int0_4201 = torch.constant.int 0
+    %4107 = torch.aten.clone %4106, %int0_4201 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %4107, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int4_4202 = torch.constant.int 4
+    %int32_4203 = torch.constant.int 32
+    %int128_4204 = torch.constant.int 128
+    %4108 = torch.prim.ListConstruct %int4_4202, %457, %int32_4203, %int128_4204 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4109 = torch.aten._unsafe_view %4107, %4108 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+    torch.bind_symbolic_shape %4109, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+    %int1_4205 = torch.constant.int 1
+    %int2_4206 = torch.constant.int 2
+    %4110 = torch.aten.transpose.int %3992, %int1_4205, %int2_4206 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+    %int1_4207 = torch.constant.int 1
+    %int2_4208 = torch.constant.int 2
+    %4111 = torch.aten.transpose.int %4103, %int1_4207, %int2_4208 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+    torch.bind_symbolic_shape %4111, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
     %int1_4209 = torch.constant.int 1
-    %3509 = torch.prim.ListConstruct %int4_4208, %int1_4209 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3510 = torch.aten.repeat %3504, %3509 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64>
-    %int32_4210 = torch.constant.int 32
-    %3511 = torch.aten.mul.Scalar %3500, %int32_4210 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4211 = torch.constant.int 1
-    %3512 = torch.aten.add.Tensor %3511, %3508, %int1_4211 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int2_4212 = torch.constant.int 2
-    %3513 = torch.aten.mul.Scalar %3512, %int2_4212 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4213 = torch.constant.int 1
-    %3514 = torch.aten.add.Tensor %3513, %3510, %int1_4213 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int32_4214 = torch.constant.int 32
-    %3515 = torch.aten.mul.Scalar %3514, %int32_4214 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4215 = torch.constant.int 1
-    %3516 = torch.aten.add.Tensor %3515, %3502, %int1_4215 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int32_4216 = torch.constant.int 32
-    %int2_4217 = torch.constant.int 2
-    %int32_4218 = torch.constant.int 32
-    %int8_4219 = torch.constant.int 8
-    %int128_4220 = torch.constant.int 128
-    %3517 = torch.prim.ListConstruct %437, %int32_4216, %int2_4217, %int32_4218, %int8_4219, %int128_4220 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3518 = torch.aten.view %3354, %3517 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3518, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
-    %int32_4221 = torch.constant.int 32
-    %3519 = torch.aten.mul.int %437, %int32_4221 : !torch.int, !torch.int -> !torch.int
-    %int2_4222 = torch.constant.int 2
-    %3520 = torch.aten.mul.int %3519, %int2_4222 : !torch.int, !torch.int -> !torch.int
-    %int32_4223 = torch.constant.int 32
-    %3521 = torch.aten.mul.int %3520, %int32_4223 : !torch.int, !torch.int -> !torch.int
-    %int8_4224 = torch.constant.int 8
-    %int128_4225 = torch.constant.int 128
-    %3522 = torch.prim.ListConstruct %3521, %int8_4224, %int128_4225 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3523 = torch.aten.view %3518, %3522 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,128],f16>
-    torch.bind_symbolic_shape %3523, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
-    %3524 = torch.prim.ListConstruct %3516 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>>
-    %false_4226 = torch.constant.bool false
-    %3525 = torch.aten.index_put %3523, %3524, %3497, %false_4226 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16>
-    torch.bind_symbolic_shape %3525, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
-    %int32_4227 = torch.constant.int 32
-    %int2_4228 = torch.constant.int 2
-    %int32_4229 = torch.constant.int 32
-    %int8_4230 = torch.constant.int 8
-    %int128_4231 = torch.constant.int 128
-    %3526 = torch.prim.ListConstruct %437, %int32_4227, %int2_4228, %int32_4229, %int8_4230, %int128_4231 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3527 = torch.aten.view %3525, %3526 : !torch.vtensor<[?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3527, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
-    %int2097152_4232 = torch.constant.int 2097152
-    %3528 = torch.prim.ListConstruct %437, %int2097152_4232 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3529 = torch.aten.view %3527, %3528 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
-    torch.bind_symbolic_shape %3529, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
-    %int32_4233 = torch.constant.int 32
-    %int2_4234 = torch.constant.int 2
-    %int32_4235 = torch.constant.int 32
-    %int8_4236 = torch.constant.int 8
-    %int128_4237 = torch.constant.int 128
-    %3530 = torch.prim.ListConstruct %437, %int32_4233, %int2_4234, %int32_4235, %int8_4236, %int128_4237 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3531 = torch.aten.view %3529, %3530 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3531, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
-    %int8_4238 = torch.constant.int 8
-    %int128_4239 = torch.constant.int 128
-    %3532 = torch.prim.ListConstruct %3521, %int8_4238, %int128_4239 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3533 = torch.aten.view %3531, %3532 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,128],f16>
-    torch.bind_symbolic_shape %3533, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
-    %int32_4240 = torch.constant.int 32
-    %3534 = torch.aten.floor_divide.Scalar %arg2, %int32_4240 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
-    %int1_4241 = torch.constant.int 1
-    %3535 = torch.aten.unsqueeze %3534, %int1_4241 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4242 = torch.constant.int 1
-    %false_4243 = torch.constant.bool false
-    %3536 = torch.aten.gather %arg3, %int1_4242, %3535, %false_4243 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64>
-    %int32_4244 = torch.constant.int 32
-    %3537 = torch.aten.remainder.Scalar %arg2, %int32_4244 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
-    %int1_4245 = torch.constant.int 1
-    %3538 = torch.aten.unsqueeze %3537, %int1_4245 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %none_4246 = torch.constant.none
-    %3539 = torch.aten.clone %171, %none_4246 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
-    %int0_4247 = torch.constant.int 0
-    %3540 = torch.aten.unsqueeze %3539, %int0_4247 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
+    %int2_4210 = torch.constant.int 2
+    %4112 = torch.aten.transpose.int %4109, %int1_4209, %int2_4210 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+    torch.bind_symbolic_shape %4112, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+    %float0.000000e00_4211 = torch.constant.float 0.000000e+00
+    %false_4212 = torch.constant.bool false
+    %none_4213 = torch.constant.none
+    %4113:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4110, %4111, %4112, %float0.000000e00_4211, %false_4212, %470, %none_4213) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
+    %int1_4214 = torch.constant.int 1
+    %int2_4215 = torch.constant.int 2
+    %4114 = torch.aten.transpose.int %4113#0, %int1_4214, %int2_4215 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
+    %int4_4216 = torch.constant.int 4
+    %int1_4217 = torch.constant.int 1
+    %int4096_4218 = torch.constant.int 4096
+    %4115 = torch.prim.ListConstruct %int4_4216, %int1_4217, %int4096_4218 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4116 = torch.aten.view %4114, %4115 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+    %int-2_4219 = torch.constant.int -2
+    %int-1_4220 = torch.constant.int -1
+    %4117 = torch.aten.transpose.int %235, %int-2_4219, %int-1_4220 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+    %int5_4221 = torch.constant.int 5
+    %4118 = torch.prims.convert_element_type %4117, %int5_4221 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+    %int4_4222 = torch.constant.int 4
+    %int4096_4223 = torch.constant.int 4096
+    %4119 = torch.prim.ListConstruct %int4_4222, %int4096_4223 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4120 = torch.aten.view %4116, %4119 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %4121 = torch.aten.mm %4120, %4118 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
+    %int4_4224 = torch.constant.int 4
+    %int1_4225 = torch.constant.int 1
+    %int4096_4226 = torch.constant.int 4096
+    %4122 = torch.prim.ListConstruct %int4_4224, %int1_4225, %int4096_4226 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4123 = torch.aten.view %4121, %4122 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+    %int1_4227 = torch.constant.int 1
+    %4124 = torch.aten.add.Tensor %3945, %4123, %int1_4227 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %int6_4228 = torch.constant.int 6
+    %4125 = torch.prims.convert_element_type %4124, %int6_4228 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %int2_4229 = torch.constant.int 2
+    %4126 = torch.aten.pow.Tensor_Scalar %4125, %int2_4229 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %int-1_4230 = torch.constant.int -1
+    %4127 = torch.prim.ListConstruct %int-1_4230 : (!torch.int) -> !torch.list<int>
+    %true_4231 = torch.constant.bool true
+    %none_4232 = torch.constant.none
+    %4128 = torch.aten.mean.dim %4126, %4127, %true_4231, %none_4232 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
+    %float9.999990e-06_4233 = torch.constant.float 9.9999997473787516E-6
+    %int1_4234 = torch.constant.int 1
+    %4129 = torch.aten.add.Scalar %4128, %float9.999990e-06_4233, %int1_4234 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
+    %4130 = torch.aten.rsqrt %4129 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
+    %4131 = torch.aten.mul.Tensor %4125, %4130 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
+    %int5_4235 = torch.constant.int 5
+    %4132 = torch.prims.convert_element_type %4131, %int5_4235 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %4133 = torch.aten.mul.Tensor %236, %4132 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
+    %int5_4236 = torch.constant.int 5
+    %4134 = torch.prims.convert_element_type %4133, %int5_4236 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %int-2_4237 = torch.constant.int -2
+    %int-1_4238 = torch.constant.int -1
+    %4135 = torch.aten.transpose.int %237, %int-2_4237, %int-1_4238 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+    %int5_4239 = torch.constant.int 5
+    %4136 = torch.prims.convert_element_type %4135, %int5_4239 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
+    %int4_4240 = torch.constant.int 4
+    %int4096_4241 = torch.constant.int 4096
+    %4137 = torch.prim.ListConstruct %int4_4240, %int4096_4241 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4138 = torch.aten.view %4134, %4137 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %4139 = torch.aten.mm %4138, %4136 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
+    %int4_4242 = torch.constant.int 4
+    %int1_4243 = torch.constant.int 1
+    %int14336_4244 = torch.constant.int 14336
+    %4140 = torch.prim.ListConstruct %int4_4242, %int1_4243, %int14336_4244 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4141 = torch.aten.view %4139, %4140 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
+    %4142 = torch.aten.silu %4141 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
+    %int-2_4245 = torch.constant.int -2
+    %int-1_4246 = torch.constant.int -1
+    %4143 = torch.aten.transpose.int %238, %int-2_4245, %int-1_4246 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+    %int5_4247 = torch.constant.int 5
+    %4144 = torch.prims.convert_element_type %4143, %int5_4247 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
     %int4_4248 = torch.constant.int 4
-    %int1_4249 = torch.constant.int 1
-    %3541 = torch.prim.ListConstruct %int4_4248, %int1_4249 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int1_4250 = torch.constant.int 1
+    %int4096_4249 = torch.constant.int 4096
+    %4145 = torch.prim.ListConstruct %int4_4248, %int4096_4249 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4146 = torch.aten.view %4134, %4145 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %4147 = torch.aten.mm %4146, %4144 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
+    %int4_4250 = torch.constant.int 4
     %int1_4251 = torch.constant.int 1
-    %3542 = torch.prim.ListConstruct %int1_4250, %int1_4251 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int4_4252 = torch.constant.int 4
-    %int0_4253 = torch.constant.int 0
-    %cpu_4254 = torch.constant.device "cpu"
-    %false_4255 = torch.constant.bool false
-    %3543 = torch.aten.empty_strided %3541, %3542, %int4_4252, %int0_4253, %cpu_4254, %false_4255 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64>
-    %int15_4256 = torch.constant.int 15
-    %3544 = torch.aten.fill.Scalar %3543, %int15_4256 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int4_4257 = torch.constant.int 4
-    %int1_4258 = torch.constant.int 1
-    %3545 = torch.prim.ListConstruct %int4_4257, %int1_4258 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3546 = torch.aten.repeat %3540, %3545 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64>
-    %int32_4259 = torch.constant.int 32
-    %3547 = torch.aten.mul.Scalar %3536, %int32_4259 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4260 = torch.constant.int 1
-    %3548 = torch.aten.add.Tensor %3547, %3544, %int1_4260 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int2_4261 = torch.constant.int 2
-    %3549 = torch.aten.mul.Scalar %3548, %int2_4261 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4262 = torch.constant.int 1
-    %3550 = torch.aten.add.Tensor %3549, %3546, %int1_4262 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int32_4263 = torch.constant.int 32
-    %3551 = torch.aten.mul.Scalar %3550, %int32_4263 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4264 = torch.constant.int 1
-    %3552 = torch.aten.add.Tensor %3551, %3538, %int1_4264 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %3553 = torch.prim.ListConstruct %3552 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>>
-    %false_4265 = torch.constant.bool false
-    %3554 = torch.aten.index_put %3533, %3553, %3485, %false_4265 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16>
-    torch.bind_symbolic_shape %3554, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
-    %int32_4266 = torch.constant.int 32
-    %int2_4267 = torch.constant.int 2
-    %int32_4268 = torch.constant.int 32
-    %int8_4269 = torch.constant.int 8
-    %int128_4270 = torch.constant.int 128
-    %3555 = torch.prim.ListConstruct %437, %int32_4266, %int2_4267, %int32_4268, %int8_4269, %int128_4270 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3556 = torch.aten.view %3554, %3555 : !torch.vtensor<[?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3556, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
-    %int2097152_4271 = torch.constant.int 2097152
-    %3557 = torch.prim.ListConstruct %437, %int2097152_4271 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3558 = torch.aten.view %3556, %3557 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
-    torch.bind_symbolic_shape %3558, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
-    %int4_4272 = torch.constant.int 4
-    %3559 = torch.prim.ListConstruct %int4_4272, %358 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int1_4273 = torch.constant.int 1
-    %3560 = torch.prim.ListConstruct %358, %int1_4273 : (!torch.int, !torch.int) -> !torch.list<int>
+    %int14336_4252 = torch.constant.int 14336
+    %4148 = torch.prim.ListConstruct %int4_4250, %int1_4251, %int14336_4252 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4149 = torch.aten.view %4147, %4148 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
+    %4150 = torch.aten.mul.Tensor %4142, %4149 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
+    %int-2_4253 = torch.constant.int -2
+    %int-1_4254 = torch.constant.int -1
+    %4151 = torch.aten.transpose.int %239, %int-2_4253, %int-1_4254 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
+    %int5_4255 = torch.constant.int 5
+    %4152 = torch.prims.convert_element_type %4151, %int5_4255 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16>
+    %int4_4256 = torch.constant.int 4
+    %int14336_4257 = torch.constant.int 14336
+    %4153 = torch.prim.ListConstruct %int4_4256, %int14336_4257 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4154 = torch.aten.view %4150, %4153 : !torch.vtensor<[4,1,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,14336],f16>
+    %4155 = torch.aten.mm %4154, %4152 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16>
+    %int4_4258 = torch.constant.int 4
+    %int1_4259 = torch.constant.int 1
+    %int4096_4260 = torch.constant.int 4096
+    %4156 = torch.prim.ListConstruct %int4_4258, %int1_4259, %int4096_4260 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4157 = torch.aten.view %4155, %4156 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+    %int1_4261 = torch.constant.int 1
+    %4158 = torch.aten.add.Tensor %4124, %4157, %int1_4261 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %int6_4262 = torch.constant.int 6
+    %4159 = torch.prims.convert_element_type %4158, %int6_4262 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %int2_4263 = torch.constant.int 2
+    %4160 = torch.aten.pow.Tensor_Scalar %4159, %int2_4263 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %int-1_4264 = torch.constant.int -1
+    %4161 = torch.prim.ListConstruct %int-1_4264 : (!torch.int) -> !torch.list<int>
+    %true_4265 = torch.constant.bool true
+    %none_4266 = torch.constant.none
+    %4162 = torch.aten.mean.dim %4160, %4161, %true_4265, %none_4266 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
+    %float9.999990e-06_4267 = torch.constant.float 9.9999997473787516E-6
+    %int1_4268 = torch.constant.int 1
+    %4163 = torch.aten.add.Scalar %4162, %float9.999990e-06_4267, %int1_4268 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
+    %4164 = torch.aten.rsqrt %4163 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
+    %4165 = torch.aten.mul.Tensor %4159, %4164 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
+    %int5_4269 = torch.constant.int 5
+    %4166 = torch.prims.convert_element_type %4165, %int5_4269 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %4167 = torch.aten.mul.Tensor %240, %4166 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
+    %int5_4270 = torch.constant.int 5
+    %4168 = torch.prims.convert_element_type %4167, %int5_4270 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %int-2_4271 = torch.constant.int -2
+    %int-1_4272 = torch.constant.int -1
+    %4169 = torch.aten.transpose.int %241, %int-2_4271, %int-1_4272 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+    %int5_4273 = torch.constant.int 5
+    %4170 = torch.prims.convert_element_type %4169, %int5_4273 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
     %int4_4274 = torch.constant.int 4
-    %int0_4275 = torch.constant.int 0
-    %cpu_4276 = torch.constant.device "cpu"
-    %false_4277 = torch.constant.bool false
-    %3561 = torch.aten.empty_strided %3559, %3560, %int4_4274, %int0_4275, %cpu_4276, %false_4277 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64>
-    torch.bind_symbolic_shape %3561, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
-    %int15_4278 = torch.constant.int 15
-    %3562 = torch.aten.fill.Scalar %3561, %int15_4278 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
-    torch.bind_symbolic_shape %3562, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
-    %int32_4279 = torch.constant.int 32
-    %3563 =
torch.aten.mul.Scalar %arg3, %int32_4279 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3563, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_4280 = torch.constant.int 1 - %3564 = torch.aten.add.Tensor %3563, %3562, %int1_4280 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3564, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_4281 = torch.constant.int 4 - %3565 = torch.aten.mul.int %int4_4281, %358 : !torch.int, !torch.int -> !torch.int - %3566 = torch.prim.ListConstruct %3565 : (!torch.int) -> !torch.list - %3567 = torch.aten.view %3564, %3566 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3567, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_4282 = torch.constant.int 32 - %int2_4283 = torch.constant.int 2 - %int32_4284 = torch.constant.int 32 - %int8_4285 = torch.constant.int 8 - %int128_4286 = torch.constant.int 128 - %3568 = torch.prim.ListConstruct %437, %int32_4282, %int2_4283, %int32_4284, %int8_4285, %int128_4286 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3569 = torch.aten.view %3558, %3568 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3569, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4287 = torch.constant.int 32 - %3570 = torch.aten.mul.int %437, %int32_4287 : !torch.int, !torch.int -> !torch.int - %int2_4288 = torch.constant.int 2 - %int32_4289 = torch.constant.int 32 - %int8_4290 = torch.constant.int 8 - %int128_4291 = torch.constant.int 128 - %3571 = torch.prim.ListConstruct %3570, %int2_4288, %int32_4289, %int8_4290, %int128_4291 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3572 = torch.aten.view %3569, %3571 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %3572, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_4292 = torch.constant.int 0 - %3573 = torch.aten.index_select %3572, %int0_4292, %3567 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %3573, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_4293 = torch.constant.int 4 - %int2_4294 = torch.constant.int 2 - %int32_4295 = torch.constant.int 32 - %int8_4296 = torch.constant.int 8 - %int128_4297 = torch.constant.int 128 - %3574 = torch.prim.ListConstruct %int4_4293, %358, %int2_4294, %int32_4295, %int8_4296, %int128_4297 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3575 = torch.aten.view %3573, %3574 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3575, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_4298 = torch.constant.int 0 - %int0_4299 = torch.constant.int 0 - %int9223372036854775807_4300 = torch.constant.int 9223372036854775807 - %int1_4301 = torch.constant.int 1 - %3576 = torch.aten.slice.Tensor %3575, %int0_4298, %int0_4299, %int9223372036854775807_4300, %int1_4301 : !torch.vtensor<[4,?,2,32,8,128],f16>, 
!torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3576, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_4302 = torch.constant.int 1 - %int0_4303 = torch.constant.int 0 - %int9223372036854775807_4304 = torch.constant.int 9223372036854775807 - %int1_4305 = torch.constant.int 1 - %3577 = torch.aten.slice.Tensor %3576, %int1_4302, %int0_4303, %int9223372036854775807_4304, %int1_4305 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3577, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_4306 = torch.constant.int 2 - %int0_4307 = torch.constant.int 0 - %3578 = torch.aten.select.int %3577, %int2_4306, %int0_4307 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3578, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_4308 = torch.constant.int 32 - %3579 = torch.aten.mul.int %358, %int32_4308 : !torch.int, !torch.int -> !torch.int - %int2_4309 = torch.constant.int 2 + %int4096_4275 = torch.constant.int 4096 + %4171 = torch.prim.ListConstruct %int4_4274, %int4096_4275 : (!torch.int, !torch.int) -> !torch.list + %4172 = torch.aten.view %4168, %4171 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4173 = torch.aten.mm %4172, %4170 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_4276 = torch.constant.int 4 + %int1_4277 = torch.constant.int 1 + %int4096_4278 = torch.constant.int 4096 + %4174 = torch.prim.ListConstruct %int4_4276, %int1_4277, %int4096_4278 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4175 = torch.aten.view %4173, %4174 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_4279 = torch.constant.int -2 + %int-1_4280 = torch.constant.int -1 + %4176 = torch.aten.transpose.int %242, %int-2_4279, %int-1_4280 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_4281 = torch.constant.int 5 + %4177 = torch.prims.convert_element_type %4176, %int5_4281 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_4282 = torch.constant.int 4 + %int4096_4283 = torch.constant.int 4096 + %4178 = torch.prim.ListConstruct %int4_4282, %int4096_4283 : (!torch.int, !torch.int) -> !torch.list + %4179 = torch.aten.view %4168, %4178 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4180 = torch.aten.mm %4179, %4177 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_4284 = torch.constant.int 4 + %int1_4285 = torch.constant.int 1 + %int1024_4286 = torch.constant.int 1024 + %4181 = torch.prim.ListConstruct %int4_4284, %int1_4285, %int1024_4286 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4182 = torch.aten.view %4180, %4181 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_4287 = torch.constant.int -2 + %int-1_4288 = torch.constant.int -1 + %4183 = torch.aten.transpose.int %243, %int-2_4287, %int-1_4288 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_4289 = torch.constant.int 5 + %4184 = torch.prims.convert_element_type %4183, 
%int5_4289 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_4290 = torch.constant.int 4 + %int4096_4291 = torch.constant.int 4096 + %4185 = torch.prim.ListConstruct %int4_4290, %int4096_4291 : (!torch.int, !torch.int) -> !torch.list + %4186 = torch.aten.view %4168, %4185 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4187 = torch.aten.mm %4186, %4184 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_4292 = torch.constant.int 4 + %int1_4293 = torch.constant.int 1 + %int1024_4294 = torch.constant.int 1024 + %4188 = torch.prim.ListConstruct %int4_4292, %int1_4293, %int1024_4294 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4189 = torch.aten.view %4187, %4188 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_4295 = torch.constant.int 4 + %int1_4296 = torch.constant.int 1 + %int32_4297 = torch.constant.int 32 + %int128_4298 = torch.constant.int 128 + %4190 = torch.prim.ListConstruct %int4_4295, %int1_4296, %int32_4297, %int128_4298 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4191 = torch.aten.view %4175, %4190 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_4299 = torch.constant.int 4 + %int1_4300 = torch.constant.int 1 + %int8_4301 = torch.constant.int 8 + %int128_4302 = torch.constant.int 128 + %4192 = torch.prim.ListConstruct %int4_4299, %int1_4300, %int8_4301, %int128_4302 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4193 = torch.aten.view %4182, %4192 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_4303 = torch.constant.int 4 + %int1_4304 = torch.constant.int 1 + %int8_4305 = torch.constant.int 8 + %int128_4306 = torch.constant.int 128 + %4194 = torch.prim.ListConstruct %int4_4303, %int1_4304, %int8_4305, %int128_4306 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4195 = torch.aten.view %4189, %4194 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_4307 = torch.constant.int 1 + %int2_4308 = torch.constant.int 2 + %4196 = torch.aten.transpose.int %4191, %int1_4307, %int2_4308 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %4197 = torch.aten.mul.Tensor %4196, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_4309 = torch.constant.int 3 %int0_4310 = torch.constant.int 0 - %int1_4311 = torch.constant.int 1 - %3580 = torch.aten.slice.Tensor %3578, %int2_4309, %int0_4310, %3579, %int1_4311 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3580, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_4312 = torch.constant.int 0 - %3581 = torch.aten.clone %3580, %int0_4312 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3581, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_4313 = torch.constant.int 1 - %3582 = torch.aten.size.int %3577, %int1_4313 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_4314 = torch.constant.int 32 - %3583 = torch.aten.mul.int %3582, %int32_4314 : !torch.int, !torch.int -> !torch.int - %int4_4315 = torch.constant.int 4 - 
%int8_4316 = torch.constant.int 8 - %int128_4317 = torch.constant.int 128 - %3584 = torch.prim.ListConstruct %int4_4315, %3583, %int8_4316, %int128_4317 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3585 = torch.aten._unsafe_view %3581, %3584 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3585, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_4318 = torch.constant.int 0 - %int0_4319 = torch.constant.int 0 - %int9223372036854775807_4320 = torch.constant.int 9223372036854775807 + %int64_4311 = torch.constant.int 64 + %int1_4312 = torch.constant.int 1 + %4198 = torch.aten.slice.Tensor %4196, %int3_4309, %int0_4310, %int64_4311, %int1_4312 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_4313 = torch.constant.int 3 + %int64_4314 = torch.constant.int 64 + %int9223372036854775807_4315 = torch.constant.int 9223372036854775807 + %int1_4316 = torch.constant.int 1 + %4199 = torch.aten.slice.Tensor %4196, %int3_4313, %int64_4314, %int9223372036854775807_4315, %int1_4316 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %4200 = torch.aten.neg %4199 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %4201 = torch.prim.ListConstruct %4200, %4198 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_4317 = torch.constant.int -1 + %4202 = torch.aten.cat %4201, %int-1_4317 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %4203 = torch.aten.mul.Tensor %4202, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_4318 = torch.constant.int 1 + %4204 = torch.aten.add.Tensor %4197, %4203, %int1_4318 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_4319 = torch.constant.int 1 + %int2_4320 = torch.constant.int 2 + %4205 = torch.aten.transpose.int %4204, %int1_4319, %int2_4320 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> %int1_4321 = torch.constant.int 1 - %3586 = torch.aten.slice.Tensor %3585, %int0_4318, %int0_4319, %int9223372036854775807_4320, %int1_4321 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3586, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_4322 = torch.constant.int 0 - %int0_4323 = torch.constant.int 0 - %int9223372036854775807_4324 = torch.constant.int 9223372036854775807 - %int1_4325 = torch.constant.int 1 - %3587 = torch.aten.slice.Tensor %3575, %int0_4322, %int0_4323, %int9223372036854775807_4324, %int1_4325 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3587, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> + %int2_4322 = torch.constant.int 2 + %4206 = torch.aten.transpose.int %4193, %int1_4321, %int2_4322 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %4207 = torch.aten.mul.Tensor %4206, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_4323 = 
torch.constant.int 3 + %int0_4324 = torch.constant.int 0 + %int64_4325 = torch.constant.int 64 %int1_4326 = torch.constant.int 1 - %int0_4327 = torch.constant.int 0 - %int9223372036854775807_4328 = torch.constant.int 9223372036854775807 - %int1_4329 = torch.constant.int 1 - %3588 = torch.aten.slice.Tensor %3587, %int1_4326, %int0_4327, %int9223372036854775807_4328, %int1_4329 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3588, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_4330 = torch.constant.int 2 - %int1_4331 = torch.constant.int 1 - %3589 = torch.aten.select.int %3588, %int2_4330, %int1_4331 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3589, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_4332 = torch.constant.int 2 - %int0_4333 = torch.constant.int 0 - %int1_4334 = torch.constant.int 1 - %3590 = torch.aten.slice.Tensor %3589, %int2_4332, %int0_4333, %3579, %int1_4334 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3590, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_4335 = torch.constant.int 0 - %3591 = torch.aten.clone %3590, %int0_4335 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3591, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %4208 = torch.aten.slice.Tensor %4206, %int3_4323, %int0_4324, %int64_4325, %int1_4326 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_4327 = torch.constant.int 3 + %int64_4328 = torch.constant.int 64 + %int9223372036854775807_4329 = torch.constant.int 9223372036854775807 + %int1_4330 = torch.constant.int 1 + %4209 = torch.aten.slice.Tensor %4206, %int3_4327, %int64_4328, %int9223372036854775807_4329, %int1_4330 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %4210 = torch.aten.neg %4209 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %4211 = torch.prim.ListConstruct %4210, %4208 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_4331 = torch.constant.int -1 + %4212 = torch.aten.cat %4211, %int-1_4331 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %4213 = torch.aten.mul.Tensor %4212, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_4332 = torch.constant.int 1 + %4214 = torch.aten.add.Tensor %4207, %4213, %int1_4332 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_4333 = torch.constant.int 1 + %int2_4334 = torch.constant.int 2 + %4215 = torch.aten.transpose.int %4214, %int1_4333, %int2_4334 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_4335 = torch.constant.int 32 + %4216 = torch.aten.floor_divide.Scalar %arg2, %int32_4335 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> %int1_4336 = torch.constant.int 1 - %3592 = torch.aten.size.int %3588, %int1_4336 : 
!torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_4337 = torch.constant.int 32 - %3593 = torch.aten.mul.int %3592, %int32_4337 : !torch.int, !torch.int -> !torch.int - %int4_4338 = torch.constant.int 4 - %int8_4339 = torch.constant.int 8 - %int128_4340 = torch.constant.int 128 - %3594 = torch.prim.ListConstruct %int4_4338, %3593, %int8_4339, %int128_4340 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3595 = torch.aten._unsafe_view %3591, %3594 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3595, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_4341 = torch.constant.int 0 - %int0_4342 = torch.constant.int 0 - %int9223372036854775807_4343 = torch.constant.int 9223372036854775807 + %4217 = torch.aten.unsqueeze %4216, %int1_4336 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_4337 = torch.constant.int 1 + %false_4338 = torch.constant.bool false + %4218 = torch.aten.gather %arg3, %int1_4337, %4217, %false_4338 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_4339 = torch.constant.int 4 + %int1_4340 = torch.constant.int 1 + %int1_4341 = torch.constant.int 1 + %4219 = torch.prim.ListConstruct %int4_4339, %int1_4340, %int1_4341 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4220 = torch.aten.view %4218, %4219 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_4342 = torch.constant.int 32 + %4221 = torch.aten.remainder.Scalar %arg2, %int32_4342 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_4343 = torch.constant.int 4 %int1_4344 = torch.constant.int 1 - %3596 = torch.aten.slice.Tensor %3595, %int0_4341, %int0_4342, %int9223372036854775807_4343, %int1_4344 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3596, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_4345 = torch.constant.int -2 - %3597 = torch.aten.unsqueeze %3586, %int-2_4345 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3597, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_4346 = torch.constant.int 1 - %3598 = torch.aten.size.int %3585, %int1_4346 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_4347 = torch.constant.int 4 - %int8_4348 = torch.constant.int 8 - %int4_4349 = torch.constant.int 4 - %int128_4350 = torch.constant.int 128 - %3599 = torch.prim.ListConstruct %int4_4347, %3598, %int8_4348, %int4_4349, %int128_4350 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_4351 = torch.constant.bool false - %3600 = torch.aten.expand %3597, %3599, %false_4351 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3600, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_4352 = torch.constant.int 0 - %3601 = torch.aten.clone %3600, %int0_4352 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3601, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_4353 = torch.constant.int 4 - %int32_4354 = 
torch.constant.int 32 - %int128_4355 = torch.constant.int 128 - %3602 = torch.prim.ListConstruct %int4_4353, %3598, %int32_4354, %int128_4355 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3603 = torch.aten._unsafe_view %3601, %3602 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3603, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_4356 = torch.constant.int -2 - %3604 = torch.aten.unsqueeze %3596, %int-2_4356 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %3604, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int1_4345 = torch.constant.int 1 + %4222 = torch.prim.ListConstruct %int4_4343, %int1_4344, %int1_4345 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4223 = torch.aten.view %4221, %4222 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_4346 = torch.constant.int 8 + %none_4347 = torch.constant.none + %none_4348 = torch.constant.none + %cpu_4349 = torch.constant.device "cpu" + %false_4350 = torch.constant.bool false + %4224 = torch.aten.arange %int8_4346, %none_4347, %none_4348, %cpu_4349, %false_4350 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_4351 = torch.constant.int 1 + %int1_4352 = torch.constant.int 1 + %int8_4353 = torch.constant.int 8 + %4225 = torch.prim.ListConstruct %int1_4351, %int1_4352, %int8_4353 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4226 = torch.aten.view %4224, %4225 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_4354 = torch.constant.none + %4227 = torch.aten.clone %244, %none_4354 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4228 = torch.aten.detach %4227 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4229 = torch.aten.detach %4228 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4230 = torch.aten.detach %4229 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_4355 = torch.constant.int 1 + %int1_4356 = torch.constant.int 1 %int1_4357 = torch.constant.int 1 - %3605 = torch.aten.size.int %3595, %int1_4357 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_4358 = torch.constant.int 4 - %int8_4359 = torch.constant.int 8 - %int4_4360 = torch.constant.int 4 - %int128_4361 = torch.constant.int 128 - %3606 = torch.prim.ListConstruct %int4_4358, %3605, %int8_4359, %int4_4360, %int128_4361 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_4362 = torch.constant.bool false - %3607 = torch.aten.expand %3604, %3606, %false_4362 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3607, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_4363 = torch.constant.int 0 - %3608 = torch.aten.clone %3607, %int0_4363 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %3608, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_4364 = torch.constant.int 4 - %int32_4365 = torch.constant.int 32 - %int128_4366 = torch.constant.int 128 - %3609 = torch.prim.ListConstruct %int4_4364, %3605, %int32_4365, %int128_4366 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - 
%3610 = torch.aten._unsafe_view %3608, %3609 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %3610, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_4367 = torch.constant.int 1 + %4231 = torch.prim.ListConstruct %int1_4355, %int1_4356, %int1_4357 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4232 = torch.aten.view %4230, %4231 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_4358 = torch.constant.int 32 + %4233 = torch.aten.mul.Scalar %4220, %int32_4358 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int17 = torch.constant.int 17 + %int1_4359 = torch.constant.int 1 + %4234 = torch.aten.add.Scalar %4233, %int17, %int1_4359 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_4360 = torch.constant.int 2 + %4235 = torch.aten.mul.Scalar %4234, %int2_4360 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_4361 = torch.constant.int 1 + %4236 = torch.aten.add.Tensor %4235, %4232, %int1_4361 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_4362 = torch.constant.int 8 + %4237 = torch.aten.mul.Scalar %4236, %int8_4362 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_4363 = torch.constant.int 1 + %4238 = torch.aten.add.Tensor %4237, %4226, %int1_4363 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_4364 = torch.constant.int 32 + %4239 = torch.aten.mul.Scalar %4238, %int32_4364 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_4365 = torch.constant.int 1 + %4240 = torch.aten.add.Tensor %4239, %4223, %int1_4365 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_4366 = torch.constant.int 5 + %4241 = torch.prims.convert_element_type %4215, %int5_4366 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_4367 = torch.constant.int 32 %int2_4368 = torch.constant.int 2 - %3611 = torch.aten.transpose.int %3491, %int1_4367, %int2_4368 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_4369 = torch.constant.int 1 - %int2_4370 = torch.constant.int 2 - %3612 = torch.aten.transpose.int %3603, %int1_4369, %int2_4370 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3612, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_4371 = torch.constant.int 1 - %int2_4372 = torch.constant.int 2 - %3613 = torch.aten.transpose.int %3610, %int1_4371, %int2_4372 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %3613, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_4373 = torch.constant.float 0.000000e+00 - %false_4374 = torch.constant.bool false - %none_4375 = torch.constant.none - %3614:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3611, %3612, %3613, %float0.000000e00_4373, %false_4374, %368, %none_4375) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, 
!torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_4376 = torch.constant.int 1 - %int2_4377 = torch.constant.int 2 - %3615 = torch.aten.transpose.int %3614#0, %int1_4376, %int2_4377 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_4378 = torch.constant.int 4 - %int1_4379 = torch.constant.int 1 - %int4096_4380 = torch.constant.int 4096 - %3616 = torch.prim.ListConstruct %int4_4378, %int1_4379, %int4096_4380 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3617 = torch.aten.view %3615, %3616 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_4381 = torch.constant.int -2 - %int-1_4382 = torch.constant.int -1 - %3618 = torch.aten.transpose.int %172, %int-2_4381, %int-1_4382 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_4383 = torch.constant.int 4 - %int4096_4384 = torch.constant.int 4096 - %3619 = torch.prim.ListConstruct %int4_4383, %int4096_4384 : (!torch.int, !torch.int) -> !torch.list - %3620 = torch.aten.view %3617, %3619 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3621 = torch.aten.mm %3620, %3618 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_4385 = torch.constant.int 4 - %int1_4386 = torch.constant.int 1 - %int4096_4387 = torch.constant.int 4096 - %3622 = torch.prim.ListConstruct %int4_4385, %int1_4386, %int4096_4387 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3623 = torch.aten.view %3621, %3622 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int8_4369 = torch.constant.int 8 + %int32_4370 = torch.constant.int 32 + %int128_4371 = torch.constant.int 128 + %4242 = torch.prim.ListConstruct %456, %int32_4367, %int2_4368, %int8_4369, %int32_4370, %int128_4371 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4243 = torch.aten.view %4063, %4242 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4243, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_4372 = torch.constant.int 128 + %4244 = torch.prim.ListConstruct %596, %int128_4372 : (!torch.int, !torch.int) -> !torch.list + %4245 = torch.aten.view %4243, %4244 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4245, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %4246 = torch.prim.ListConstruct %4240 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_4373 = torch.constant.bool false + %4247 = torch.aten.index_put %4245, %4246, %4241, %false_4373 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4247, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_4374 = torch.constant.int 32 + %int2_4375 = torch.constant.int 2 + %int8_4376 = torch.constant.int 8 + %int32_4377 = torch.constant.int 32 + %int128_4378 = torch.constant.int 128 + %4248 = torch.prim.ListConstruct %456, %int32_4374, %int2_4375, %int8_4376, %int32_4377, %int128_4378 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4249 = torch.aten.view %4247, %4248 : !torch.vtensor<[?,128],f16>, 
!torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4249, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_4379 = torch.constant.int 2097152 + %4250 = torch.prim.ListConstruct %456, %int2097152_4379 : (!torch.int, !torch.int) -> !torch.list + %4251 = torch.aten.view %4249, %4250 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %4251, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_4380 = torch.constant.int 32 + %int2_4381 = torch.constant.int 2 + %int8_4382 = torch.constant.int 8 + %int32_4383 = torch.constant.int 32 + %int128_4384 = torch.constant.int 128 + %4252 = torch.prim.ListConstruct %456, %int32_4380, %int2_4381, %int8_4382, %int32_4383, %int128_4384 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4253 = torch.aten.view %4251, %4252 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4253, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_4385 = torch.constant.int 128 + %4254 = torch.prim.ListConstruct %596, %int128_4385 : (!torch.int, !torch.int) -> !torch.list + %4255 = torch.aten.view %4253, %4254 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4255, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_4386 = torch.constant.none + %4256 = torch.aten.clone %245, %none_4386 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4257 = torch.aten.detach %4256 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4258 = torch.aten.detach %4257 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4259 = torch.aten.detach %4258 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_4387 = torch.constant.int 1 %int1_4388 = torch.constant.int 1 - %3624 = torch.aten.add.Tensor %3451, %3623, %int1_4388 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_4389 = torch.constant.int 6 - %3625 = torch.prims.convert_element_type %3624, %int6_4389 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_4390 = torch.constant.int 2 - %3626 = torch.aten.pow.Tensor_Scalar %3625, %int2_4390 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_4391 = torch.constant.int -1 - %3627 = torch.prim.ListConstruct %int-1_4391 : (!torch.int) -> !torch.list - %true_4392 = torch.constant.bool true - %none_4393 = torch.constant.none - %3628 = torch.aten.mean.dim %3626, %3627, %true_4392, %none_4393 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_4394 = torch.constant.float 9.9999997473787516E-6 - %int1_4395 = torch.constant.int 1 - %3629 = torch.aten.add.Scalar %3628, %float9.999990e-06_4394, %int1_4395 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %3630 = torch.aten.rsqrt %3629 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %3631 = torch.aten.mul.Tensor %3625, %3630 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_4396 = torch.constant.int 5 - %3632 = torch.prims.convert_element_type %3631, %int5_4396 : 
!torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %3633 = torch.aten.mul.Tensor %173, %3632 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_4397 = torch.constant.int 5 - %3634 = torch.prims.convert_element_type %3633, %int5_4397 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_4398 = torch.constant.int -2 - %int-1_4399 = torch.constant.int -1 - %3635 = torch.aten.transpose.int %174, %int-2_4398, %int-1_4399 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_4400 = torch.constant.int 4 - %int4096_4401 = torch.constant.int 4096 - %3636 = torch.prim.ListConstruct %int4_4400, %int4096_4401 : (!torch.int, !torch.int) -> !torch.list - %3637 = torch.aten.view %3634, %3636 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3638 = torch.aten.mm %3637, %3635 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_4402 = torch.constant.int 4 - %int1_4403 = torch.constant.int 1 - %int14336_4404 = torch.constant.int 14336 - %3639 = torch.prim.ListConstruct %int4_4402, %int1_4403, %int14336_4404 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3640 = torch.aten.view %3638, %3639 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %3641 = torch.aten.silu %3640 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_4405 = torch.constant.int -2 - %int-1_4406 = torch.constant.int -1 - %3642 = torch.aten.transpose.int %175, %int-2_4405, %int-1_4406 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_4407 = torch.constant.int 4 - %int4096_4408 = torch.constant.int 4096 - %3643 = torch.prim.ListConstruct %int4_4407, %int4096_4408 : (!torch.int, !torch.int) -> !torch.list - %3644 = torch.aten.view %3634, %3643 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3645 = torch.aten.mm %3644, %3642 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_4409 = torch.constant.int 4 - %int1_4410 = torch.constant.int 1 - %int14336_4411 = torch.constant.int 14336 - %3646 = torch.prim.ListConstruct %int4_4409, %int1_4410, %int14336_4411 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3647 = torch.aten.view %3645, %3646 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %3648 = torch.aten.mul.Tensor %3641, %3647 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_4412 = torch.constant.int -2 - %int-1_4413 = torch.constant.int -1 - %3649 = torch.aten.transpose.int %176, %int-2_4412, %int-1_4413 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_4414 = torch.constant.int 4 - %int14336_4415 = torch.constant.int 14336 - %3650 = torch.prim.ListConstruct %int4_4414, %int14336_4415 : (!torch.int, !torch.int) -> !torch.list - %3651 = torch.aten.view %3648, %3650 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %3652 = torch.aten.mm %3651, %3649 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_4416 = torch.constant.int 4 - %int1_4417 = torch.constant.int 1 - %int4096_4418 = torch.constant.int 4096 - %3653 = torch.prim.ListConstruct %int4_4416, %int1_4417, 
%int4096_4418 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3654 = torch.aten.view %3652, %3653 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_4419 = torch.constant.int 1 - %3655 = torch.aten.add.Tensor %3624, %3654, %int1_4419 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_4420 = torch.constant.int 6 - %3656 = torch.prims.convert_element_type %3655, %int6_4420 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_4421 = torch.constant.int 2 - %3657 = torch.aten.pow.Tensor_Scalar %3656, %int2_4421 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_4422 = torch.constant.int -1 - %3658 = torch.prim.ListConstruct %int-1_4422 : (!torch.int) -> !torch.list - %true_4423 = torch.constant.bool true - %none_4424 = torch.constant.none - %3659 = torch.aten.mean.dim %3657, %3658, %true_4423, %none_4424 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_4425 = torch.constant.float 9.9999997473787516E-6 - %int1_4426 = torch.constant.int 1 - %3660 = torch.aten.add.Scalar %3659, %float9.999990e-06_4425, %int1_4426 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %3661 = torch.aten.rsqrt %3660 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %3662 = torch.aten.mul.Tensor %3656, %3661 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_4427 = torch.constant.int 5 - %3663 = torch.prims.convert_element_type %3662, %int5_4427 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %3664 = torch.aten.mul.Tensor %177, %3663 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_4428 = torch.constant.int 5 - %3665 = torch.prims.convert_element_type %3664, %int5_4428 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_4429 = torch.constant.int -2 - %int-1_4430 = torch.constant.int -1 - %3666 = torch.aten.transpose.int %178, %int-2_4429, %int-1_4430 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_4431 = torch.constant.int 4 - %int4096_4432 = torch.constant.int 4096 - %3667 = torch.prim.ListConstruct %int4_4431, %int4096_4432 : (!torch.int, !torch.int) -> !torch.list - %3668 = torch.aten.view %3665, %3667 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3669 = torch.aten.mm %3668, %3666 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_4433 = torch.constant.int 4 - %int1_4434 = torch.constant.int 1 - %int4096_4435 = torch.constant.int 4096 - %3670 = torch.prim.ListConstruct %int4_4433, %int1_4434, %int4096_4435 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3671 = torch.aten.view %3669, %3670 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_4436 = torch.constant.int -2 - %int-1_4437 = torch.constant.int -1 - %3672 = torch.aten.transpose.int %179, %int-2_4436, %int-1_4437 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int1_4389 = torch.constant.int 1 + %4260 = torch.prim.ListConstruct %int1_4387, %int1_4388, %int1_4389 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4261 = torch.aten.view %4259, %4260 : 
!torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_4390 = torch.constant.int 32 + %4262 = torch.aten.mul.Scalar %4220, %int32_4390 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int17_4391 = torch.constant.int 17 + %int1_4392 = torch.constant.int 1 + %4263 = torch.aten.add.Scalar %4262, %int17_4391, %int1_4392 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_4393 = torch.constant.int 2 + %4264 = torch.aten.mul.Scalar %4263, %int2_4393 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_4394 = torch.constant.int 1 + %4265 = torch.aten.add.Tensor %4264, %4261, %int1_4394 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_4395 = torch.constant.int 8 + %4266 = torch.aten.mul.Scalar %4265, %int8_4395 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_4396 = torch.constant.int 1 + %4267 = torch.aten.add.Tensor %4266, %4226, %int1_4396 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_4397 = torch.constant.int 32 + %4268 = torch.aten.mul.Scalar %4267, %int32_4397 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_4398 = torch.constant.int 1 + %4269 = torch.aten.add.Tensor %4268, %4223, %int1_4398 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_4399 = torch.constant.int 5 + %4270 = torch.prims.convert_element_type %4195, %int5_4399 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %4271 = torch.prim.ListConstruct %4269 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_4400 = torch.constant.bool false + %4272 = torch.aten.index_put %4255, %4271, %4270, %false_4400 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4272, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_4401 = torch.constant.int 32 + %int2_4402 = torch.constant.int 2 + %int8_4403 = torch.constant.int 8 + %int32_4404 = torch.constant.int 32 + %int128_4405 = torch.constant.int 128 + %4273 = torch.prim.ListConstruct %456, %int32_4401, %int2_4402, %int8_4403, %int32_4404, %int128_4405 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4274 = torch.aten.view %4272, %4273 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4274, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_4406 = torch.constant.int 2097152 + %4275 = torch.prim.ListConstruct %456, %int2097152_4406 : (!torch.int, !torch.int) -> !torch.list + %4276 = torch.aten.view %4274, %4275 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %4276, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_4407 = torch.constant.none + %4277 = torch.aten.clone %246, %none_4407 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4278 = torch.aten.detach %4277 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4279 = torch.aten.detach %4278 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4280 = torch.aten.detach %4279 : !torch.vtensor<[],si64> -> 
!torch.vtensor<[],si64> + %none_4408 = torch.constant.none + %4281 = torch.aten.clone %247, %none_4408 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4282 = torch.aten.detach %4281 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4283 = torch.aten.detach %4282 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4284 = torch.aten.detach %4283 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_4409 = torch.constant.none + %4285 = torch.aten.clone %248, %none_4409 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4286 = torch.aten.detach %4285 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4287 = torch.aten.detach %4286 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4288 = torch.aten.detach %4287 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_4410 = torch.constant.int 32 + %int2_4411 = torch.constant.int 2 + %int8_4412 = torch.constant.int 8 + %int32_4413 = torch.constant.int 32 + %int128_4414 = torch.constant.int 128 + %4289 = torch.prim.ListConstruct %456, %int32_4410, %int2_4411, %int8_4412, %int32_4413, %int128_4414 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4290 = torch.aten.view %4276, %4289 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4290, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %4291 = torch_c.to_builtin_tensor %4290 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %4292 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_4415 = tensor.cast %4292 : tensor<4x?xi64> to tensor + %4293 = torch_c.to_builtin_tensor %4280 : !torch.vtensor<[],si64> -> tensor + %4294 = torch_c.to_builtin_tensor %4284 : !torch.vtensor<[],si64> -> tensor + %4295 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%4291, %cast_4415, %4293, %4294) : (tensor, tensor, tensor, tensor) -> tensor + %cast_4416 = tensor.cast %4295 : tensor to tensor<4x?x8x32x128xf16> + %4296 = torch_c.from_builtin_tensor %cast_4416 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %4296, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %4297 = torch_c.to_builtin_tensor %4290 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %4298 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_4417 = tensor.cast %4298 : tensor<4x?xi64> to tensor + %4299 = torch_c.to_builtin_tensor %4280 : !torch.vtensor<[],si64> -> tensor + %4300 = torch_c.to_builtin_tensor %4288 : !torch.vtensor<[],si64> -> tensor + %4301 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%4297, %cast_4417, %4299, %4300) : (tensor, tensor, tensor, tensor) -> tensor + %cast_4418 = tensor.cast %4301 : tensor to tensor<4x?x8x32x128xf16> + %4302 = torch_c.from_builtin_tensor %cast_4418 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %4302, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_4419 = torch.constant.int 2 + 
%int3_4420 = torch.constant.int 3 + %4303 = torch.aten.transpose.int %4296, %int2_4419, %int3_4420 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4303, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_4421 = torch.constant.int 0 + %4304 = torch.aten.clone %4303, %int0_4421 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4304, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_4422 = torch.constant.int 4 + %int8_4423 = torch.constant.int 8 + %int128_4424 = torch.constant.int 128 + %4305 = torch.prim.ListConstruct %int4_4422, %457, %int8_4423, %int128_4424 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4306 = torch.aten._unsafe_view %4304, %4305 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4306, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_4425 = torch.constant.int 2 + %int3_4426 = torch.constant.int 3 + %4307 = torch.aten.transpose.int %4302, %int2_4425, %int3_4426 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4307, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_4427 = torch.constant.int 0 + %4308 = torch.aten.clone %4307, %int0_4427 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4308, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_4428 = torch.constant.int 4 + %int8_4429 = torch.constant.int 8 + %int128_4430 = torch.constant.int 128 + %4309 = torch.prim.ListConstruct %int4_4428, %457, %int8_4429, %int128_4430 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4310 = torch.aten._unsafe_view %4308, %4309 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4310, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_4431 = torch.constant.int -2 + %4311 = torch.aten.unsqueeze %4306, %int-2_4431 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %4311, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_4432 = torch.constant.int 4 + %int8_4433 = torch.constant.int 8 + %int4_4434 = torch.constant.int 4 + %int128_4435 = torch.constant.int 128 + %4312 = torch.prim.ListConstruct %int4_4432, %457, %int8_4433, %int4_4434, %int128_4435 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_4436 = torch.constant.bool false + %4313 = torch.aten.expand %4311, %4312, %false_4436 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4313, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_4437 = torch.constant.int 0 + %4314 = torch.aten.clone %4313, %int0_4437 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4314, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_4438 = torch.constant.int 4 - %int4096_4439 = 
torch.constant.int 4096 - %3673 = torch.prim.ListConstruct %int4_4438, %int4096_4439 : (!torch.int, !torch.int) -> !torch.list - %3674 = torch.aten.view %3665, %3673 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3675 = torch.aten.mm %3674, %3672 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_4440 = torch.constant.int 4 - %int1_4441 = torch.constant.int 1 - %int1024_4442 = torch.constant.int 1024 - %3676 = torch.prim.ListConstruct %int4_4440, %int1_4441, %int1024_4442 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3677 = torch.aten.view %3675, %3676 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_4443 = torch.constant.int -2 - %int-1_4444 = torch.constant.int -1 - %3678 = torch.aten.transpose.int %180, %int-2_4443, %int-1_4444 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_4445 = torch.constant.int 4 - %int4096_4446 = torch.constant.int 4096 - %3679 = torch.prim.ListConstruct %int4_4445, %int4096_4446 : (!torch.int, !torch.int) -> !torch.list - %3680 = torch.aten.view %3665, %3679 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %3681 = torch.aten.mm %3680, %3678 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_4447 = torch.constant.int 4 - %int1_4448 = torch.constant.int 1 - %int1024_4449 = torch.constant.int 1024 - %3682 = torch.prim.ListConstruct %int4_4447, %int1_4448, %int1024_4449 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3683 = torch.aten.view %3681, %3682 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_4450 = torch.constant.int 4 + %int32_4439 = torch.constant.int 32 + %int128_4440 = torch.constant.int 128 + %4315 = torch.prim.ListConstruct %int4_4438, %457, %int32_4439, %int128_4440 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4316 = torch.aten._unsafe_view %4314, %4315 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4316, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_4441 = torch.constant.int -2 + %4317 = torch.aten.unsqueeze %4310, %int-2_4441 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %4317, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_4442 = torch.constant.int 4 + %int8_4443 = torch.constant.int 8 + %int4_4444 = torch.constant.int 4 + %int128_4445 = torch.constant.int 128 + %4318 = torch.prim.ListConstruct %int4_4442, %457, %int8_4443, %int4_4444, %int128_4445 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_4446 = torch.constant.bool false + %4319 = torch.aten.expand %4317, %4318, %false_4446 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4319, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_4447 = torch.constant.int 0 + %4320 = torch.aten.clone %4319, %int0_4447 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4320, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_4448 = torch.constant.int 4 + %int32_4449 = 
torch.constant.int 32 + %int128_4450 = torch.constant.int 128 + %4321 = torch.prim.ListConstruct %int4_4448, %457, %int32_4449, %int128_4450 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4322 = torch.aten._unsafe_view %4320, %4321 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4322, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int1_4451 = torch.constant.int 1 - %int32_4452 = torch.constant.int 32 - %int128_4453 = torch.constant.int 128 - %3684 = torch.prim.ListConstruct %int4_4450, %int1_4451, %int32_4452, %int128_4453 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3685 = torch.aten.view %3671, %3684 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_4454 = torch.constant.int 4 + %int2_4452 = torch.constant.int 2 + %4323 = torch.aten.transpose.int %4205, %int1_4451, %int2_4452 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_4453 = torch.constant.int 1 + %int2_4454 = torch.constant.int 2 + %4324 = torch.aten.transpose.int %4316, %int1_4453, %int2_4454 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4324, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_4455 = torch.constant.int 1 - %int8_4456 = torch.constant.int 8 - %int128_4457 = torch.constant.int 128 - %3686 = torch.prim.ListConstruct %int4_4454, %int1_4455, %int8_4456, %int128_4457 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3687 = torch.aten.view %3677, %3686 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_4458 = torch.constant.int 4 - %int1_4459 = torch.constant.int 1 - %int8_4460 = torch.constant.int 8 - %int128_4461 = torch.constant.int 128 - %3688 = torch.prim.ListConstruct %int4_4458, %int1_4459, %int8_4460, %int128_4461 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3689 = torch.aten.view %3683, %3688 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_4462 = torch.constant.int 6 - %3690 = torch.prims.convert_element_type %3685, %int6_4462 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %3691 = torch_c.to_builtin_tensor %3690 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %3692 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %3693 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%3691, %3692) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %3694 = torch_c.from_builtin_tensor %3693 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_4463 = torch.constant.int 5 - %3695 = torch.prims.convert_element_type %3694, %int5_4463 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_4464 = torch.constant.int 6 - %3696 = torch.prims.convert_element_type %3687, %int6_4464 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %3697 = torch_c.to_builtin_tensor %3696 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %3698 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %3699 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%3697, %3698) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> 
tensor<4x1x8x128xf32> - %3700 = torch_c.from_builtin_tensor %3699 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_4465 = torch.constant.int 5 - %3701 = torch.prims.convert_element_type %3700, %int5_4465 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_4466 = torch.constant.int 32 - %3702 = torch.aten.floor_divide.Scalar %arg2, %int32_4466 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_4467 = torch.constant.int 1 - %3703 = torch.aten.unsqueeze %3702, %int1_4467 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_4468 = torch.constant.int 1 - %false_4469 = torch.constant.bool false - %3704 = torch.aten.gather %arg3, %int1_4468, %3703, %false_4469 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_4470 = torch.constant.int 32 - %3705 = torch.aten.remainder.Scalar %arg2, %int32_4470 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int2_4456 = torch.constant.int 2 + %4325 = torch.aten.transpose.int %4322, %int1_4455, %int2_4456 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4325, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_4457 = torch.constant.float 0.000000e+00 + %false_4458 = torch.constant.bool false + %none_4459 = torch.constant.none + %4326:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4323, %4324, %4325, %float0.000000e00_4457, %false_4458, %470, %none_4459) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_4460 = torch.constant.int 1 + %int2_4461 = torch.constant.int 2 + %4327 = torch.aten.transpose.int %4326#0, %int1_4460, %int2_4461 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_4462 = torch.constant.int 4 + %int1_4463 = torch.constant.int 1 + %int4096_4464 = torch.constant.int 4096 + %4328 = torch.prim.ListConstruct %int4_4462, %int1_4463, %int4096_4464 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4329 = torch.aten.view %4327, %4328 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_4465 = torch.constant.int -2 + %int-1_4466 = torch.constant.int -1 + %4330 = torch.aten.transpose.int %249, %int-2_4465, %int-1_4466 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_4467 = torch.constant.int 5 + %4331 = torch.prims.convert_element_type %4330, %int5_4467 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_4468 = torch.constant.int 4 + %int4096_4469 = torch.constant.int 4096 + %4332 = torch.prim.ListConstruct %int4_4468, %int4096_4469 : (!torch.int, !torch.int) -> !torch.list + %4333 = torch.aten.view %4329, %4332 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4334 = torch.aten.mm %4333, %4331 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_4470 = torch.constant.int 4 %int1_4471 = torch.constant.int 1 - %3706 = torch.aten.unsqueeze %3705, %int1_4471 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_4472 = 
torch.constant.none - %3707 = torch.aten.clone %181, %none_4472 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_4473 = torch.constant.int 0 - %3708 = torch.aten.unsqueeze %3707, %int0_4473 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_4474 = torch.constant.int 4 - %int1_4475 = torch.constant.int 1 - %3709 = torch.prim.ListConstruct %int4_4474, %int1_4475 : (!torch.int, !torch.int) -> !torch.list - %int1_4476 = torch.constant.int 1 - %int1_4477 = torch.constant.int 1 - %3710 = torch.prim.ListConstruct %int1_4476, %int1_4477 : (!torch.int, !torch.int) -> !torch.list - %int4_4478 = torch.constant.int 4 - %int0_4479 = torch.constant.int 0 - %cpu_4480 = torch.constant.device "cpu" - %false_4481 = torch.constant.bool false - %3711 = torch.aten.empty_strided %3709, %3710, %int4_4478, %int0_4479, %cpu_4480, %false_4481 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int16 = torch.constant.int 16 - %3712 = torch.aten.fill.Scalar %3711, %int16 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_4482 = torch.constant.int 4 - %int1_4483 = torch.constant.int 1 - %3713 = torch.prim.ListConstruct %int4_4482, %int1_4483 : (!torch.int, !torch.int) -> !torch.list - %3714 = torch.aten.repeat %3708, %3713 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_4484 = torch.constant.int 32 - %3715 = torch.aten.mul.Scalar %3704, %int32_4484 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_4485 = torch.constant.int 1 - %3716 = torch.aten.add.Tensor %3715, %3712, %int1_4485 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_4486 = torch.constant.int 2 - %3717 = torch.aten.mul.Scalar %3716, %int2_4486 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_4487 = torch.constant.int 1 - %3718 = torch.aten.add.Tensor %3717, %3714, %int1_4487 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_4488 = torch.constant.int 32 - %3719 = torch.aten.mul.Scalar %3718, %int32_4488 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int4096_4472 = torch.constant.int 4096 + %4335 = torch.prim.ListConstruct %int4_4470, %int1_4471, %int4096_4472 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4336 = torch.aten.view %4334, %4335 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_4473 = torch.constant.int 1 + %4337 = torch.aten.add.Tensor %4158, %4336, %int1_4473 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_4474 = torch.constant.int 6 + %4338 = torch.prims.convert_element_type %4337, %int6_4474 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_4475 = torch.constant.int 2 + %4339 = torch.aten.pow.Tensor_Scalar %4338, %int2_4475 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_4476 = torch.constant.int -1 + %4340 = torch.prim.ListConstruct %int-1_4476 : (!torch.int) -> !torch.list + %true_4477 = torch.constant.bool true + %none_4478 = torch.constant.none + %4341 = torch.aten.mean.dim %4339, %4340, %true_4477, %none_4478 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_4479 = torch.constant.float 9.9999997473787516E-6 + 
%int1_4480 = torch.constant.int 1 + %4342 = torch.aten.add.Scalar %4341, %float9.999990e-06_4479, %int1_4480 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %4343 = torch.aten.rsqrt %4342 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %4344 = torch.aten.mul.Tensor %4338, %4343 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_4481 = torch.constant.int 5 + %4345 = torch.prims.convert_element_type %4344, %int5_4481 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %4346 = torch.aten.mul.Tensor %250, %4345 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_4482 = torch.constant.int 5 + %4347 = torch.prims.convert_element_type %4346, %int5_4482 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_4483 = torch.constant.int -2 + %int-1_4484 = torch.constant.int -1 + %4348 = torch.aten.transpose.int %251, %int-2_4483, %int-1_4484 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_4485 = torch.constant.int 5 + %4349 = torch.prims.convert_element_type %4348, %int5_4485 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_4486 = torch.constant.int 4 + %int4096_4487 = torch.constant.int 4096 + %4350 = torch.prim.ListConstruct %int4_4486, %int4096_4487 : (!torch.int, !torch.int) -> !torch.list + %4351 = torch.aten.view %4347, %4350 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4352 = torch.aten.mm %4351, %4349 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_4488 = torch.constant.int 4 %int1_4489 = torch.constant.int 1 - %3720 = torch.aten.add.Tensor %3719, %3706, %int1_4489 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_4490 = torch.constant.int 32 - %int2_4491 = torch.constant.int 2 - %int32_4492 = torch.constant.int 32 - %int8_4493 = torch.constant.int 8 - %int128_4494 = torch.constant.int 128 - %3721 = torch.prim.ListConstruct %437, %int32_4490, %int2_4491, %int32_4492, %int8_4493, %int128_4494 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3722 = torch.aten.view %3558, %3721 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3722, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4495 = torch.constant.int 32 - %3723 = torch.aten.mul.int %437, %int32_4495 : !torch.int, !torch.int -> !torch.int - %int2_4496 = torch.constant.int 2 - %3724 = torch.aten.mul.int %3723, %int2_4496 : !torch.int, !torch.int -> !torch.int - %int32_4497 = torch.constant.int 32 - %3725 = torch.aten.mul.int %3724, %int32_4497 : !torch.int, !torch.int -> !torch.int - %int8_4498 = torch.constant.int 8 - %int128_4499 = torch.constant.int 128 - %3726 = torch.prim.ListConstruct %3725, %int8_4498, %int128_4499 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3727 = torch.aten.view %3722, %3726 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %3727, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %3728 = torch.prim.ListConstruct %3720 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_4500 = 
torch.constant.bool false - %3729 = torch.aten.index_put %3727, %3728, %3701, %false_4500 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %3729, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_4501 = torch.constant.int 32 - %int2_4502 = torch.constant.int 2 - %int32_4503 = torch.constant.int 32 - %int8_4504 = torch.constant.int 8 - %int128_4505 = torch.constant.int 128 - %3730 = torch.prim.ListConstruct %437, %int32_4501, %int2_4502, %int32_4503, %int8_4504, %int128_4505 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3731 = torch.aten.view %3729, %3730 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3731, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_4506 = torch.constant.int 2097152 - %3732 = torch.prim.ListConstruct %437, %int2097152_4506 : (!torch.int, !torch.int) -> !torch.list - %3733 = torch.aten.view %3731, %3732 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3733, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_4507 = torch.constant.int 32 - %int2_4508 = torch.constant.int 2 - %int32_4509 = torch.constant.int 32 - %int8_4510 = torch.constant.int 8 - %int128_4511 = torch.constant.int 128 - %3734 = torch.prim.ListConstruct %437, %int32_4507, %int2_4508, %int32_4509, %int8_4510, %int128_4511 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3735 = torch.aten.view %3733, %3734 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3735, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_4512 = torch.constant.int 8 - %int128_4513 = torch.constant.int 128 - %3736 = torch.prim.ListConstruct %3725, %int8_4512, %int128_4513 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %3737 = torch.aten.view %3735, %3736 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %3737, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_4514 = torch.constant.int 32 - %3738 = torch.aten.floor_divide.Scalar %arg2, %int32_4514 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_4515 = torch.constant.int 1 - %3739 = torch.aten.unsqueeze %3738, %int1_4515 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_4516 = torch.constant.int 1 - %false_4517 = torch.constant.bool false - %3740 = torch.aten.gather %arg3, %int1_4516, %3739, %false_4517 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_4518 = torch.constant.int 32 - %3741 = torch.aten.remainder.Scalar %arg2, %int32_4518 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_4519 = torch.constant.int 1 - %3742 = torch.aten.unsqueeze %3741, %int1_4519 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_4520 = torch.constant.none - %3743 = torch.aten.clone %182, %none_4520 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_4521 = torch.constant.int 0 - %3744 = torch.aten.unsqueeze %3743, %int0_4521 
: !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + %int14336_4490 = torch.constant.int 14336 + %4353 = torch.prim.ListConstruct %int4_4488, %int1_4489, %int14336_4490 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4354 = torch.aten.view %4352, %4353 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %4355 = torch.aten.silu %4354 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_4491 = torch.constant.int -2 + %int-1_4492 = torch.constant.int -1 + %4356 = torch.aten.transpose.int %252, %int-2_4491, %int-1_4492 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_4493 = torch.constant.int 5 + %4357 = torch.prims.convert_element_type %4356, %int5_4493 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_4494 = torch.constant.int 4 + %int4096_4495 = torch.constant.int 4096 + %4358 = torch.prim.ListConstruct %int4_4494, %int4096_4495 : (!torch.int, !torch.int) -> !torch.list + %4359 = torch.aten.view %4347, %4358 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4360 = torch.aten.mm %4359, %4357 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_4496 = torch.constant.int 4 + %int1_4497 = torch.constant.int 1 + %int14336_4498 = torch.constant.int 14336 + %4361 = torch.prim.ListConstruct %int4_4496, %int1_4497, %int14336_4498 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4362 = torch.aten.view %4360, %4361 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %4363 = torch.aten.mul.Tensor %4355, %4362 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_4499 = torch.constant.int -2 + %int-1_4500 = torch.constant.int -1 + %4364 = torch.aten.transpose.int %253, %int-2_4499, %int-1_4500 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_4501 = torch.constant.int 5 + %4365 = torch.prims.convert_element_type %4364, %int5_4501 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_4502 = torch.constant.int 4 + %int14336_4503 = torch.constant.int 14336 + %4366 = torch.prim.ListConstruct %int4_4502, %int14336_4503 : (!torch.int, !torch.int) -> !torch.list + %4367 = torch.aten.view %4363, %4366 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %4368 = torch.aten.mm %4367, %4365 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_4504 = torch.constant.int 4 + %int1_4505 = torch.constant.int 1 + %int4096_4506 = torch.constant.int 4096 + %4369 = torch.prim.ListConstruct %int4_4504, %int1_4505, %int4096_4506 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4370 = torch.aten.view %4368, %4369 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_4507 = torch.constant.int 1 + %4371 = torch.aten.add.Tensor %4337, %4370, %int1_4507 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_4508 = torch.constant.int 6 + %4372 = torch.prims.convert_element_type %4371, %int6_4508 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_4509 = torch.constant.int 2 + %4373 = torch.aten.pow.Tensor_Scalar %4372, %int2_4509 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> 
!torch.vtensor<[4,1,4096],f32> + %int-1_4510 = torch.constant.int -1 + %4374 = torch.prim.ListConstruct %int-1_4510 : (!torch.int) -> !torch.list + %true_4511 = torch.constant.bool true + %none_4512 = torch.constant.none + %4375 = torch.aten.mean.dim %4373, %4374, %true_4511, %none_4512 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_4513 = torch.constant.float 9.9999997473787516E-6 + %int1_4514 = torch.constant.int 1 + %4376 = torch.aten.add.Scalar %4375, %float9.999990e-06_4513, %int1_4514 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %4377 = torch.aten.rsqrt %4376 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %4378 = torch.aten.mul.Tensor %4372, %4377 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_4515 = torch.constant.int 5 + %4379 = torch.prims.convert_element_type %4378, %int5_4515 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %4380 = torch.aten.mul.Tensor %254, %4379 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_4516 = torch.constant.int 5 + %4381 = torch.prims.convert_element_type %4380, %int5_4516 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_4517 = torch.constant.int -2 + %int-1_4518 = torch.constant.int -1 + %4382 = torch.aten.transpose.int %255, %int-2_4517, %int-1_4518 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_4519 = torch.constant.int 5 + %4383 = torch.prims.convert_element_type %4382, %int5_4519 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_4520 = torch.constant.int 4 + %int4096_4521 = torch.constant.int 4096 + %4384 = torch.prim.ListConstruct %int4_4520, %int4096_4521 : (!torch.int, !torch.int) -> !torch.list + %4385 = torch.aten.view %4381, %4384 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4386 = torch.aten.mm %4385, %4383 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_4522 = torch.constant.int 4 %int1_4523 = torch.constant.int 1 - %3745 = torch.prim.ListConstruct %int4_4522, %int1_4523 : (!torch.int, !torch.int) -> !torch.list - %int1_4524 = torch.constant.int 1 - %int1_4525 = torch.constant.int 1 - %3746 = torch.prim.ListConstruct %int1_4524, %int1_4525 : (!torch.int, !torch.int) -> !torch.list - %int4_4526 = torch.constant.int 4 - %int0_4527 = torch.constant.int 0 - %cpu_4528 = torch.constant.device "cpu" - %false_4529 = torch.constant.bool false - %3747 = torch.aten.empty_strided %3745, %3746, %int4_4526, %int0_4527, %cpu_4528, %false_4529 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int16_4530 = torch.constant.int 16 - %3748 = torch.aten.fill.Scalar %3747, %int16_4530 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_4531 = torch.constant.int 4 - %int1_4532 = torch.constant.int 1 - %3749 = torch.prim.ListConstruct %int4_4531, %int1_4532 : (!torch.int, !torch.int) -> !torch.list - %3750 = torch.aten.repeat %3744, %3749 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_4533 = torch.constant.int 32 - %3751 = torch.aten.mul.Scalar %3740, %int32_4533 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_4534 = 
torch.constant.int 1 - %3752 = torch.aten.add.Tensor %3751, %3748, %int1_4534 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_4535 = torch.constant.int 2 - %3753 = torch.aten.mul.Scalar %3752, %int2_4535 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_4536 = torch.constant.int 1 - %3754 = torch.aten.add.Tensor %3753, %3750, %int1_4536 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_4537 = torch.constant.int 32 - %3755 = torch.aten.mul.Scalar %3754, %int32_4537 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_4538 = torch.constant.int 1 - %3756 = torch.aten.add.Tensor %3755, %3742, %int1_4538 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %3757 = torch.prim.ListConstruct %3756 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_4539 = torch.constant.bool false - %3758 = torch.aten.index_put %3737, %3757, %3689, %false_4539 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %3758, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_4540 = torch.constant.int 32 - %int2_4541 = torch.constant.int 2 - %int32_4542 = torch.constant.int 32 - %int8_4543 = torch.constant.int 8 + %int4096_4524 = torch.constant.int 4096 + %4387 = torch.prim.ListConstruct %int4_4522, %int1_4523, %int4096_4524 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4388 = torch.aten.view %4386, %4387 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_4525 = torch.constant.int -2 + %int-1_4526 = torch.constant.int -1 + %4389 = torch.aten.transpose.int %256, %int-2_4525, %int-1_4526 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_4527 = torch.constant.int 5 + %4390 = torch.prims.convert_element_type %4389, %int5_4527 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_4528 = torch.constant.int 4 + %int4096_4529 = torch.constant.int 4096 + %4391 = torch.prim.ListConstruct %int4_4528, %int4096_4529 : (!torch.int, !torch.int) -> !torch.list + %4392 = torch.aten.view %4381, %4391 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4393 = torch.aten.mm %4392, %4390 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_4530 = torch.constant.int 4 + %int1_4531 = torch.constant.int 1 + %int1024_4532 = torch.constant.int 1024 + %4394 = torch.prim.ListConstruct %int4_4530, %int1_4531, %int1024_4532 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4395 = torch.aten.view %4393, %4394 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_4533 = torch.constant.int -2 + %int-1_4534 = torch.constant.int -1 + %4396 = torch.aten.transpose.int %257, %int-2_4533, %int-1_4534 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_4535 = torch.constant.int 5 + %4397 = torch.prims.convert_element_type %4396, %int5_4535 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_4536 = torch.constant.int 4 + %int4096_4537 = torch.constant.int 4096 + %4398 = torch.prim.ListConstruct %int4_4536, %int4096_4537 : (!torch.int, !torch.int) -> !torch.list + %4399 = 
torch.aten.view %4381, %4398 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4400 = torch.aten.mm %4399, %4397 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_4538 = torch.constant.int 4 + %int1_4539 = torch.constant.int 1 + %int1024_4540 = torch.constant.int 1024 + %4401 = torch.prim.ListConstruct %int4_4538, %int1_4539, %int1024_4540 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4402 = torch.aten.view %4400, %4401 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_4541 = torch.constant.int 4 + %int1_4542 = torch.constant.int 1 + %int32_4543 = torch.constant.int 32 %int128_4544 = torch.constant.int 128 - %3759 = torch.prim.ListConstruct %437, %int32_4540, %int2_4541, %int32_4542, %int8_4543, %int128_4544 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3760 = torch.aten.view %3758, %3759 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3760, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_4545 = torch.constant.int 2097152 - %3761 = torch.prim.ListConstruct %437, %int2097152_4545 : (!torch.int, !torch.int) -> !torch.list - %3762 = torch.aten.view %3760, %3761 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %3762, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_4546 = torch.constant.int 4 - %3763 = torch.prim.ListConstruct %int4_4546, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_4547 = torch.constant.int 1 - %3764 = torch.prim.ListConstruct %358, %int1_4547 : (!torch.int, !torch.int) -> !torch.list - %int4_4548 = torch.constant.int 4 - %int0_4549 = torch.constant.int 0 - %cpu_4550 = torch.constant.device "cpu" - %false_4551 = torch.constant.bool false - %3765 = torch.aten.empty_strided %3763, %3764, %int4_4548, %int0_4549, %cpu_4550, %false_4551 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3765, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int16_4552 = torch.constant.int 16 - %3766 = torch.aten.fill.Scalar %3765, %int16_4552 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3766, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_4553 = torch.constant.int 32 - %3767 = torch.aten.mul.Scalar %arg3, %int32_4553 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3767, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_4554 = torch.constant.int 1 - %3768 = torch.aten.add.Tensor %3767, %3766, %int1_4554 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3768, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_4555 = torch.constant.int 4 - %3769 = torch.aten.mul.int %int4_4555, %358 : !torch.int, !torch.int -> !torch.int - %3770 = torch.prim.ListConstruct %3769 : (!torch.int) -> !torch.list - %3771 = torch.aten.view %3768, %3770 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3771, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_4556 = 
torch.constant.int 32 - %int2_4557 = torch.constant.int 2 - %int32_4558 = torch.constant.int 32 - %int8_4559 = torch.constant.int 8 - %int128_4560 = torch.constant.int 128 - %3772 = torch.prim.ListConstruct %437, %int32_4556, %int2_4557, %int32_4558, %int8_4559, %int128_4560 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3773 = torch.aten.view %3762, %3772 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3773, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4561 = torch.constant.int 32 - %3774 = torch.aten.mul.int %437, %int32_4561 : !torch.int, !torch.int -> !torch.int - %int2_4562 = torch.constant.int 2 - %int32_4563 = torch.constant.int 32 - %int8_4564 = torch.constant.int 8 - %int128_4565 = torch.constant.int 128 - %3775 = torch.prim.ListConstruct %3774, %int2_4562, %int32_4563, %int8_4564, %int128_4565 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3776 = torch.aten.view %3773, %3775 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %3776, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_4566 = torch.constant.int 0 - %3777 = torch.aten.index_select %3776, %int0_4566, %3771 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %3777, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_4567 = torch.constant.int 4 + %4403 = torch.prim.ListConstruct %int4_4541, %int1_4542, %int32_4543, %int128_4544 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4404 = torch.aten.view %4388, %4403 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_4545 = torch.constant.int 4 + %int1_4546 = torch.constant.int 1 + %int8_4547 = torch.constant.int 8 + %int128_4548 = torch.constant.int 128 + %4405 = torch.prim.ListConstruct %int4_4545, %int1_4546, %int8_4547, %int128_4548 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4406 = torch.aten.view %4395, %4405 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_4549 = torch.constant.int 4 + %int1_4550 = torch.constant.int 1 + %int8_4551 = torch.constant.int 8 + %int128_4552 = torch.constant.int 128 + %4407 = torch.prim.ListConstruct %int4_4549, %int1_4550, %int8_4551, %int128_4552 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4408 = torch.aten.view %4402, %4407 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_4553 = torch.constant.int 1 + %int2_4554 = torch.constant.int 2 + %4409 = torch.aten.transpose.int %4404, %int1_4553, %int2_4554 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %4410 = torch.aten.mul.Tensor %4409, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_4555 = torch.constant.int 3 + %int0_4556 = torch.constant.int 0 + %int64_4557 = torch.constant.int 64 + %int1_4558 = torch.constant.int 1 + %4411 = torch.aten.slice.Tensor %4409, %int3_4555, %int0_4556, %int64_4557, %int1_4558 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_4559 = 
torch.constant.int 3 + %int64_4560 = torch.constant.int 64 + %int9223372036854775807_4561 = torch.constant.int 9223372036854775807 + %int1_4562 = torch.constant.int 1 + %4412 = torch.aten.slice.Tensor %4409, %int3_4559, %int64_4560, %int9223372036854775807_4561, %int1_4562 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %4413 = torch.aten.neg %4412 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %4414 = torch.prim.ListConstruct %4413, %4411 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_4563 = torch.constant.int -1 + %4415 = torch.aten.cat %4414, %int-1_4563 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %4416 = torch.aten.mul.Tensor %4415, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_4564 = torch.constant.int 1 + %4417 = torch.aten.add.Tensor %4410, %4416, %int1_4564 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_4565 = torch.constant.int 1 + %int2_4566 = torch.constant.int 2 + %4418 = torch.aten.transpose.int %4417, %int1_4565, %int2_4566 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_4567 = torch.constant.int 1 %int2_4568 = torch.constant.int 2 - %int32_4569 = torch.constant.int 32 - %int8_4570 = torch.constant.int 8 - %int128_4571 = torch.constant.int 128 - %3778 = torch.prim.ListConstruct %int4_4567, %358, %int2_4568, %int32_4569, %int8_4570, %int128_4571 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3779 = torch.aten.view %3777, %3778 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3779, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_4572 = torch.constant.int 0 - %int0_4573 = torch.constant.int 0 - %int9223372036854775807_4574 = torch.constant.int 9223372036854775807 - %int1_4575 = torch.constant.int 1 - %3780 = torch.aten.slice.Tensor %3779, %int0_4572, %int0_4573, %int9223372036854775807_4574, %int1_4575 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3780, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> + %4419 = torch.aten.transpose.int %4406, %int1_4567, %int2_4568 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %4420 = torch.aten.mul.Tensor %4419, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_4569 = torch.constant.int 3 + %int0_4570 = torch.constant.int 0 + %int64_4571 = torch.constant.int 64 + %int1_4572 = torch.constant.int 1 + %4421 = torch.aten.slice.Tensor %4419, %int3_4569, %int0_4570, %int64_4571, %int1_4572 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_4573 = torch.constant.int 3 + %int64_4574 = torch.constant.int 64 + %int9223372036854775807_4575 = torch.constant.int 9223372036854775807 %int1_4576 = torch.constant.int 1 - %int0_4577 = torch.constant.int 0 - %int9223372036854775807_4578 = torch.constant.int 9223372036854775807 + %4422 = torch.aten.slice.Tensor %4419, %int3_4573, %int64_4574, 
%int9223372036854775807_4575, %int1_4576 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %4423 = torch.aten.neg %4422 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %4424 = torch.prim.ListConstruct %4423, %4421 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_4577 = torch.constant.int -1 + %4425 = torch.aten.cat %4424, %int-1_4577 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %4426 = torch.aten.mul.Tensor %4425, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_4578 = torch.constant.int 1 + %4427 = torch.aten.add.Tensor %4420, %4426, %int1_4578 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> %int1_4579 = torch.constant.int 1 - %3781 = torch.aten.slice.Tensor %3780, %int1_4576, %int0_4577, %int9223372036854775807_4578, %int1_4579 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3781, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> %int2_4580 = torch.constant.int 2 - %int0_4581 = torch.constant.int 0 - %3782 = torch.aten.select.int %3781, %int2_4580, %int0_4581 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3782, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_4582 = torch.constant.int 32 - %3783 = torch.aten.mul.int %358, %int32_4582 : !torch.int, !torch.int -> !torch.int - %int2_4583 = torch.constant.int 2 - %int0_4584 = torch.constant.int 0 - %int1_4585 = torch.constant.int 1 - %3784 = torch.aten.slice.Tensor %3782, %int2_4583, %int0_4584, %3783, %int1_4585 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3784, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_4586 = torch.constant.int 0 - %3785 = torch.aten.clone %3784, %int0_4586 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3785, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %4428 = torch.aten.transpose.int %4427, %int1_4579, %int2_4580 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_4581 = torch.constant.int 32 + %4429 = torch.aten.floor_divide.Scalar %arg2, %int32_4581 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_4582 = torch.constant.int 1 + %4430 = torch.aten.unsqueeze %4429, %int1_4582 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_4583 = torch.constant.int 1 + %false_4584 = torch.constant.bool false + %4431 = torch.aten.gather %arg3, %int1_4583, %4430, %false_4584 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_4585 = torch.constant.int 4 + %int1_4586 = torch.constant.int 1 %int1_4587 = torch.constant.int 1 - %3786 = torch.aten.size.int %3781, %int1_4587 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int + %4432 = torch.prim.ListConstruct %int4_4585, %int1_4586, %int1_4587 : (!torch.int, !torch.int, !torch.int) -> 
!torch.list + %4433 = torch.aten.view %4431, %4432 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> %int32_4588 = torch.constant.int 32 - %3787 = torch.aten.mul.int %3786, %int32_4588 : !torch.int, !torch.int -> !torch.int + %4434 = torch.aten.remainder.Scalar %arg2, %int32_4588 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> %int4_4589 = torch.constant.int 4 - %int8_4590 = torch.constant.int 8 - %int128_4591 = torch.constant.int 128 - %3788 = torch.prim.ListConstruct %int4_4589, %3787, %int8_4590, %int128_4591 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3789 = torch.aten._unsafe_view %3785, %3788 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3789, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_4592 = torch.constant.int 0 - %int0_4593 = torch.constant.int 0 - %int9223372036854775807_4594 = torch.constant.int 9223372036854775807 - %int1_4595 = torch.constant.int 1 - %3790 = torch.aten.slice.Tensor %3789, %int0_4592, %int0_4593, %int9223372036854775807_4594, %int1_4595 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3790, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_4596 = torch.constant.int 0 - %int0_4597 = torch.constant.int 0 - %int9223372036854775807_4598 = torch.constant.int 9223372036854775807 - %int1_4599 = torch.constant.int 1 - %3791 = torch.aten.slice.Tensor %3779, %int0_4596, %int0_4597, %int9223372036854775807_4598, %int1_4599 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3791, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_4600 = torch.constant.int 1 - %int0_4601 = torch.constant.int 0 - %int9223372036854775807_4602 = torch.constant.int 9223372036854775807 + %int1_4590 = torch.constant.int 1 + %int1_4591 = torch.constant.int 1 + %4435 = torch.prim.ListConstruct %int4_4589, %int1_4590, %int1_4591 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4436 = torch.aten.view %4434, %4435 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_4592 = torch.constant.int 8 + %none_4593 = torch.constant.none + %none_4594 = torch.constant.none + %cpu_4595 = torch.constant.device "cpu" + %false_4596 = torch.constant.bool false + %4437 = torch.aten.arange %int8_4592, %none_4593, %none_4594, %cpu_4595, %false_4596 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_4597 = torch.constant.int 1 + %int1_4598 = torch.constant.int 1 + %int8_4599 = torch.constant.int 8 + %4438 = torch.prim.ListConstruct %int1_4597, %int1_4598, %int8_4599 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4439 = torch.aten.view %4437, %4438 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_4600 = torch.constant.none + %4440 = torch.aten.clone %258, %none_4600 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4441 = torch.aten.detach %4440 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4442 = torch.aten.detach %4441 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4443 = torch.aten.detach %4442 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_4601 = torch.constant.int 1 + 
%int1_4602 = torch.constant.int 1 %int1_4603 = torch.constant.int 1 - %3792 = torch.aten.slice.Tensor %3791, %int1_4600, %int0_4601, %int9223372036854775807_4602, %int1_4603 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3792, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_4604 = torch.constant.int 2 + %4444 = torch.prim.ListConstruct %int1_4601, %int1_4602, %int1_4603 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4445 = torch.aten.view %4443, %4444 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_4604 = torch.constant.int 32 + %4446 = torch.aten.mul.Scalar %4433, %int32_4604 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int18 = torch.constant.int 18 %int1_4605 = torch.constant.int 1 - %3793 = torch.aten.select.int %3792, %int2_4604, %int1_4605 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3793, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %4447 = torch.aten.add.Scalar %4446, %int18, %int1_4605 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> %int2_4606 = torch.constant.int 2 - %int0_4607 = torch.constant.int 0 - %int1_4608 = torch.constant.int 1 - %3794 = torch.aten.slice.Tensor %3793, %int2_4606, %int0_4607, %3783, %int1_4608 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3794, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_4609 = torch.constant.int 0 - %3795 = torch.aten.clone %3794, %int0_4609 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3795, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_4610 = torch.constant.int 1 - %3796 = torch.aten.size.int %3792, %int1_4610 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_4611 = torch.constant.int 32 - %3797 = torch.aten.mul.int %3796, %int32_4611 : !torch.int, !torch.int -> !torch.int - %int4_4612 = torch.constant.int 4 - %int8_4613 = torch.constant.int 8 - %int128_4614 = torch.constant.int 128 - %3798 = torch.prim.ListConstruct %int4_4612, %3797, %int8_4613, %int128_4614 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3799 = torch.aten._unsafe_view %3795, %3798 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3799, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_4615 = torch.constant.int 0 - %int0_4616 = torch.constant.int 0 - %int9223372036854775807_4617 = torch.constant.int 9223372036854775807 - %int1_4618 = torch.constant.int 1 - %3800 = torch.aten.slice.Tensor %3799, %int0_4615, %int0_4616, %int9223372036854775807_4617, %int1_4618 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3800, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_4619 = torch.constant.int -2 - %3801 = torch.aten.unsqueeze %3790, %int-2_4619 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> 
!torch.vtensor<[4,?,8,1,128],f16>
-    torch.bind_symbolic_shape %3801, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
-    %int1_4620 = torch.constant.int 1
-    %3802 = torch.aten.size.int %3789, %int1_4620 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int
-    %int4_4621 = torch.constant.int 4
+    %4448 = torch.aten.mul.Scalar %4447, %int2_4606 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int1_4607 = torch.constant.int 1
+    %4449 = torch.aten.add.Tensor %4448, %4445, %int1_4607 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int8_4608 = torch.constant.int 8
+    %4450 = torch.aten.mul.Scalar %4449, %int8_4608 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int1_4609 = torch.constant.int 1
+    %4451 = torch.aten.add.Tensor %4450, %4439, %int1_4609 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+    %int32_4610 = torch.constant.int 32
+    %4452 = torch.aten.mul.Scalar %4451, %int32_4610 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+    %int1_4611 = torch.constant.int 1
+    %4453 = torch.aten.add.Tensor %4452, %4436, %int1_4611 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+    %int5_4612 = torch.constant.int 5
+    %4454 = torch.prims.convert_element_type %4428, %int5_4612 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+    %int32_4613 = torch.constant.int 32
+    %int2_4614 = torch.constant.int 2
+    %int8_4615 = torch.constant.int 8
+    %int32_4616 = torch.constant.int 32
+    %int128_4617 = torch.constant.int 128
+    %4455 = torch.prim.ListConstruct %456, %int32_4613, %int2_4614, %int8_4615, %int32_4616, %int128_4617 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4456 = torch.aten.view %4276, %4455 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4456, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %int128_4618 = torch.constant.int 128
+    %4457 = torch.prim.ListConstruct %596, %int128_4618 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4458 = torch.aten.view %4456, %4457 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+    torch.bind_symbolic_shape %4458, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+    %4459 = torch.prim.ListConstruct %4453 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+    %false_4619 = torch.constant.bool false
+    %4460 = torch.aten.index_put %4458, %4459, %4454, %false_4619 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+    torch.bind_symbolic_shape %4460, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+    %int32_4620 = torch.constant.int 32
+    %int2_4621 = torch.constant.int 2
     %int8_4622 = torch.constant.int 8
-    %int4_4623 = torch.constant.int 4
+    %int32_4623 = torch.constant.int 32
     %int128_4624 = torch.constant.int 128
-    %3803 = torch.prim.ListConstruct %int4_4621, %3802, %int8_4622, %int4_4623, %int128_4624 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %false_4625 = torch.constant.bool false
-    %3804 = torch.aten.expand %3801, %3803, %false_4625 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %3804, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int0_4626 = torch.constant.int 0
-    %3805 = torch.aten.clone %3804, %int0_4626 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %3805, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int4_4627 = torch.constant.int 4
-    %int32_4628 = torch.constant.int 32
-    %int128_4629 = torch.constant.int 128
-    %3806 = torch.prim.ListConstruct %int4_4627, %3802, %int32_4628, %int128_4629 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3807 = torch.aten._unsafe_view %3805, %3806 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
-    torch.bind_symbolic_shape %3807, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
-    %int-2_4630 = torch.constant.int -2
-    %3808 = torch.aten.unsqueeze %3800, %int-2_4630 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
-    torch.bind_symbolic_shape %3808, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
-    %int1_4631 = torch.constant.int 1
-    %3809 = torch.aten.size.int %3799, %int1_4631 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int
-    %int4_4632 = torch.constant.int 4
-    %int8_4633 = torch.constant.int 8
-    %int4_4634 = torch.constant.int 4
-    %int128_4635 = torch.constant.int 128
-    %3810 = torch.prim.ListConstruct %int4_4632, %3809, %int8_4633, %int4_4634, %int128_4635 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %false_4636 = torch.constant.bool false
-    %3811 = torch.aten.expand %3808, %3810, %false_4636 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %3811, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int0_4637 = torch.constant.int 0
-    %3812 = torch.aten.clone %3811, %int0_4637 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %3812, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int4_4638 = torch.constant.int 4
-    %int32_4639 = torch.constant.int 32
-    %int128_4640 = torch.constant.int 128
-    %3813 = torch.prim.ListConstruct %int4_4638, %3809, %int32_4639, %int128_4640 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3814 = torch.aten._unsafe_view %3812, %3813 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
-    torch.bind_symbolic_shape %3814, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
-    %int1_4641 = torch.constant.int 1
-    %int2_4642 = torch.constant.int 2
-    %3815 = torch.aten.transpose.int %3695, %int1_4641, %int2_4642 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
-    %int1_4643 = torch.constant.int 1
-    %int2_4644 = torch.constant.int 2
-    %3816 = torch.aten.transpose.int %3807, %int1_4643, %int2_4644 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
-    torch.bind_symbolic_shape %3816, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
-    %int1_4645 = torch.constant.int 1
-    %int2_4646 = torch.constant.int 2
-    %3817 = torch.aten.transpose.int %3814, %int1_4645, %int2_4646 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
-    torch.bind_symbolic_shape %3817, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
-    %float0.000000e00_4647 = torch.constant.float 0.000000e+00
-    %false_4648 = torch.constant.bool false
-    %none_4649 = torch.constant.none
-    %3818:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%3815, %3816, %3817, %float0.000000e00_4647, %false_4648, %368, %none_4649) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
-    %int1_4650 = torch.constant.int 1
-    %int2_4651 = torch.constant.int 2
-    %3819 = torch.aten.transpose.int %3818#0, %int1_4650, %int2_4651 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
-    %int4_4652 = torch.constant.int 4
-    %int1_4653 = torch.constant.int 1
-    %int4096_4654 = torch.constant.int 4096
-    %3820 = torch.prim.ListConstruct %int4_4652, %int1_4653, %int4096_4654 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3821 = torch.aten.view %3819, %3820 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
-    %int-2_4655 = torch.constant.int -2
-    %int-1_4656 = torch.constant.int -1
-    %3822 = torch.aten.transpose.int %183, %int-2_4655, %int-1_4656 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
-    %int4_4657 = torch.constant.int 4
-    %int4096_4658 = torch.constant.int 4096
-    %3823 = torch.prim.ListConstruct %int4_4657, %int4096_4658 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3824 = torch.aten.view %3821, %3823 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %3825 = torch.aten.mm %3824, %3822 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
-    %int4_4659 = torch.constant.int 4
-    %int1_4660 = torch.constant.int 1
-    %int4096_4661 = torch.constant.int 4096
-    %3826 = torch.prim.ListConstruct %int4_4659, %int1_4660, %int4096_4661 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3827 = torch.aten.view %3825, %3826 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
-    %int1_4662 = torch.constant.int 1
-    %3828 = torch.aten.add.Tensor %3655, %3827, %int1_4662 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %int6_4663 = torch.constant.int 6
-    %3829 = torch.prims.convert_element_type %3828, %int6_4663 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
-    %int2_4664 = torch.constant.int 2
-    %3830 = torch.aten.pow.Tensor_Scalar %3829, %int2_4664 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
-    %int-1_4665 = torch.constant.int -1
-    %3831 = torch.prim.ListConstruct %int-1_4665 : (!torch.int) -> !torch.list<int>
-    %true_4666 = torch.constant.bool true
-    %none_4667 = torch.constant.none
-    %3832 = torch.aten.mean.dim %3830, %3831, %true_4666, %none_4667 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
-    %float9.999990e-06_4668 = torch.constant.float 9.9999997473787516E-6
-    %int1_4669 = torch.constant.int 1
-    %3833 = torch.aten.add.Scalar %3832, %float9.999990e-06_4668, %int1_4669 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
-    %3834 = torch.aten.rsqrt %3833 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
-    %3835 = torch.aten.mul.Tensor %3829, %3834 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
-    %int5_4670 = torch.constant.int 5
-    %3836 = torch.prims.convert_element_type %3835, %int5_4670 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %3837 = torch.aten.mul.Tensor %184, %3836 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
-    %int5_4671 = torch.constant.int 5
-    %3838 = torch.prims.convert_element_type %3837, %int5_4671 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %int-2_4672 = torch.constant.int -2
-    %int-1_4673 = torch.constant.int -1
-    %3839 = torch.aten.transpose.int %185, %int-2_4672, %int-1_4673 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+    %4461 = torch.prim.ListConstruct %456, %int32_4620, %int2_4621, %int8_4622, %int32_4623, %int128_4624 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4462 = torch.aten.view %4460, %4461 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4462, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %int2097152_4625 = torch.constant.int 2097152
+    %4463 = torch.prim.ListConstruct %456, %int2097152_4625 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4464 = torch.aten.view %4462, %4463 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+    torch.bind_symbolic_shape %4464, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+    %int32_4626 = torch.constant.int 32
+    %int2_4627 = torch.constant.int 2
+    %int8_4628 = torch.constant.int 8
+    %int32_4629 = torch.constant.int 32
+    %int128_4630 = torch.constant.int 128
+    %4465 = torch.prim.ListConstruct %456, %int32_4626, %int2_4627, %int8_4628, %int32_4629, %int128_4630 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4466 = torch.aten.view %4464, %4465 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4466, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %int128_4631 = torch.constant.int 128
+    %4467 = torch.prim.ListConstruct %596, %int128_4631 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4468 = torch.aten.view %4466, %4467 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+    torch.bind_symbolic_shape %4468, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+    %none_4632 = torch.constant.none
+    %4469 = torch.aten.clone %259, %none_4632 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+    %4470 = torch.aten.detach %4469 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4471 = torch.aten.detach %4470 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4472 = torch.aten.detach %4471 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %int1_4633 = torch.constant.int 1
+    %int1_4634 = torch.constant.int 1
+    %int1_4635 = torch.constant.int 1
+    %4473 = torch.prim.ListConstruct %int1_4633, %int1_4634, %int1_4635 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4474 = torch.aten.view %4472, %4473 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64>
+    %int32_4636 = torch.constant.int 32
+    %4475 = torch.aten.mul.Scalar %4433, %int32_4636 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int18_4637 = torch.constant.int 18
+    %int1_4638 = torch.constant.int 1
+    %4476 = torch.aten.add.Scalar %4475, %int18_4637, %int1_4638 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int2_4639 = torch.constant.int 2
+    %4477 = torch.aten.mul.Scalar %4476, %int2_4639 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int1_4640 = torch.constant.int 1
+    %4478 = torch.aten.add.Tensor %4477, %4474, %int1_4640 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int8_4641 = torch.constant.int 8
+    %4479 = torch.aten.mul.Scalar %4478, %int8_4641 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+    %int1_4642 = torch.constant.int 1
+    %4480 = torch.aten.add.Tensor %4479, %4439, %int1_4642 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+    %int32_4643 = torch.constant.int 32
+    %4481 = torch.aten.mul.Scalar %4480, %int32_4643 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+    %int1_4644 = torch.constant.int 1
+    %4482 = torch.aten.add.Tensor %4481, %4436, %int1_4644 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+    %int5_4645 = torch.constant.int 5
+    %4483 = torch.prims.convert_element_type %4408, %int5_4645 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+    %4484 = torch.prim.ListConstruct %4482 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+    %false_4646 = torch.constant.bool false
+    %4485 = torch.aten.index_put %4468, %4484, %4483, %false_4646 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+    torch.bind_symbolic_shape %4485, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+    %int32_4647 = torch.constant.int 32
+    %int2_4648 = torch.constant.int 2
+    %int8_4649 = torch.constant.int 8
+    %int32_4650 = torch.constant.int 32
+    %int128_4651 = torch.constant.int 128
+    %4486 = torch.prim.ListConstruct %456, %int32_4647, %int2_4648, %int8_4649, %int32_4650, %int128_4651 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4487 = torch.aten.view %4485, %4486 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4487, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %int2097152_4652 = torch.constant.int 2097152
+    %4488 = torch.prim.ListConstruct %456, %int2097152_4652 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4489 = torch.aten.view %4487, %4488 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+    torch.bind_symbolic_shape %4489, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+    %none_4653 = torch.constant.none
+    %4490 = torch.aten.clone %260, %none_4653 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+    %4491 = torch.aten.detach %4490 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4492 = torch.aten.detach %4491 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4493 = torch.aten.detach %4492 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %none_4654 = torch.constant.none
+    %4494 = torch.aten.clone %261, %none_4654 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+    %4495 = torch.aten.detach %4494 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4496 = torch.aten.detach %4495 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4497 = torch.aten.detach %4496 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %none_4655 = torch.constant.none
+    %4498 = torch.aten.clone %262, %none_4655 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+    %4499 = torch.aten.detach %4498 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4500 = torch.aten.detach %4499 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %4501 = torch.aten.detach %4500 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+    %int32_4656 = torch.constant.int 32
+    %int2_4657 = torch.constant.int 2
+    %int8_4658 = torch.constant.int 8
+    %int32_4659 = torch.constant.int 32
+    %int128_4660 = torch.constant.int 128
+    %4502 = torch.prim.ListConstruct %456, %int32_4656, %int2_4657, %int8_4658, %int32_4659, %int128_4660 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4503 = torch.aten.view %4489, %4502 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+    torch.bind_symbolic_shape %4503, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+    %4504 = torch_c.to_builtin_tensor %4503 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+    %4505 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+    %cast_4661 = tensor.cast %4505 : tensor<4x?xi64> to tensor<?x?xi64>
+    %4506 = torch_c.to_builtin_tensor %4493 : !torch.vtensor<[],si64> -> tensor<i64>
+    %4507 = torch_c.to_builtin_tensor %4497 : !torch.vtensor<[],si64> -> tensor<i64>
+    %4508 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%4504, %cast_4661, %4506, %4507) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+    %cast_4662 = tensor.cast %4508 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+    %4509 = torch_c.from_builtin_tensor %cast_4662 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+    torch.bind_symbolic_shape %4509, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+    %4510 = torch_c.to_builtin_tensor %4503 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+    %4511 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+    %cast_4663 = tensor.cast %4511 : tensor<4x?xi64> to tensor<?x?xi64>
+    %4512 = torch_c.to_builtin_tensor %4493 : !torch.vtensor<[],si64> -> tensor<i64>
+    %4513 = torch_c.to_builtin_tensor %4501 : !torch.vtensor<[],si64> -> tensor<i64>
+    %4514 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%4510, %cast_4663, %4512, %4513) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+    %cast_4664 = tensor.cast %4514 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+    %4515 = torch_c.from_builtin_tensor %cast_4664 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+    torch.bind_symbolic_shape %4515, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+    %int2_4665 = torch.constant.int 2
+    %int3_4666 = torch.constant.int 3
+    %4516 = torch.aten.transpose.int %4509, %int2_4665, %int3_4666 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %4516, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+    %int0_4667 = torch.constant.int 0
+    %4517 = torch.aten.clone %4516, %int0_4667 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %4517, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+    %int4_4668 = torch.constant.int 4
+    %int8_4669 = torch.constant.int 8
+    %int128_4670 = torch.constant.int 128
+    %4518 = torch.prim.ListConstruct %int4_4668, %457, %int8_4669, %int128_4670 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4519 = torch.aten._unsafe_view %4517, %4518 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+    torch.bind_symbolic_shape %4519, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+    %int2_4671 = torch.constant.int 2
+    %int3_4672 = torch.constant.int 3
+    %4520 = torch.aten.transpose.int %4515, %int2_4671, %int3_4672 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %4520, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+    %int0_4673 = torch.constant.int 0
+    %4521 = torch.aten.clone %4520, %int0_4673 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+    torch.bind_symbolic_shape %4521, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
     %int4_4674 = torch.constant.int 4
-    %int4096_4675 = torch.constant.int 4096
-    %3840 = torch.prim.ListConstruct %int4_4674, %int4096_4675 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3841 = torch.aten.view %3838, %3840 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %3842 = torch.aten.mm %3841, %3839 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
-    %int4_4676 = torch.constant.int 4
-    %int1_4677 = torch.constant.int 1
-    %int14336_4678 = torch.constant.int 14336
-    %3843 = torch.prim.ListConstruct %int4_4676, %int1_4677, %int14336_4678 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3844 = torch.aten.view %3842, %3843 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
-    %3845 = torch.aten.silu %3844 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
-    %int-2_4679 = torch.constant.int -2
-    %int-1_4680 = torch.constant.int -1
-    %3846 = torch.aten.transpose.int %186, %int-2_4679, %int-1_4680 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
-    %int4_4681 = torch.constant.int 4
-    %int4096_4682 = torch.constant.int 4096
-    %3847 = torch.prim.ListConstruct %int4_4681, %int4096_4682 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3848 = torch.aten.view %3838, %3847 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %3849 = torch.aten.mm %3848, %3846 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
-    %int4_4683 = torch.constant.int 4
-    %int1_4684 = torch.constant.int 1
-    %int14336_4685 = torch.constant.int 14336
-    %3850 = torch.prim.ListConstruct %int4_4683, %int1_4684, %int14336_4685 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3851 = torch.aten.view %3849, %3850 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
-    %3852 = torch.aten.mul.Tensor %3845, %3851 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
-    %int-2_4686 = torch.constant.int -2
-    %int-1_4687 = torch.constant.int -1
-    %3853 = torch.aten.transpose.int %187, %int-2_4686, %int-1_4687 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
+    %int8_4675 = torch.constant.int 8
+    %int128_4676 = torch.constant.int 128
+    %4522 = torch.prim.ListConstruct %int4_4674, %457, %int8_4675, %int128_4676 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4523 = torch.aten._unsafe_view %4521, %4522 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+    torch.bind_symbolic_shape %4523, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+    %int-2_4677 = torch.constant.int -2
+    %4524 = torch.aten.unsqueeze %4519, %int-2_4677 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+    torch.bind_symbolic_shape %4524, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+    %int4_4678 = torch.constant.int 4
+    %int8_4679 = torch.constant.int 8
+    %int4_4680 = torch.constant.int 4
+    %int128_4681 = torch.constant.int 128
+    %4525 = torch.prim.ListConstruct %int4_4678, %457, %int8_4679, %int4_4680, %int128_4681 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %false_4682 = torch.constant.bool false
+    %4526 = torch.aten.expand %4524, %4525, %false_4682 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %4526, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int0_4683 = torch.constant.int 0
+    %4527 = torch.aten.clone %4526, %int0_4683 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %4527, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int4_4684 = torch.constant.int 4
+    %int32_4685 = torch.constant.int 32
+    %int128_4686 = torch.constant.int 128
+    %4528 = torch.prim.ListConstruct %int4_4684, %457, %int32_4685, %int128_4686 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4529 = torch.aten._unsafe_view %4527, %4528 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+    torch.bind_symbolic_shape %4529, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+    %int-2_4687 = torch.constant.int -2
+    %4530 = torch.aten.unsqueeze %4523, %int-2_4687 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+    torch.bind_symbolic_shape %4530, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
     %int4_4688 = torch.constant.int 4
-    %int14336_4689 = torch.constant.int 14336
-    %3854 = torch.prim.ListConstruct %int4_4688, %int14336_4689 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3855 = torch.aten.view %3852, %3854 : !torch.vtensor<[4,1,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,14336],f16>
-    %3856 = torch.aten.mm %3855, %3853 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16>
+    %int8_4689 = torch.constant.int 8
     %int4_4690 = torch.constant.int 4
-    %int1_4691 = torch.constant.int 1
-    %int4096_4692 = torch.constant.int 4096
-    %3857 = torch.prim.ListConstruct %int4_4690, %int1_4691, %int4096_4692 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3858 = torch.aten.view %3856, %3857 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
-    %int1_4693 = torch.constant.int 1
-    %3859 = torch.aten.add.Tensor %3828, %3858, %int1_4693 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %int6_4694 = torch.constant.int 6
-    %3860 = torch.prims.convert_element_type %3859, %int6_4694 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
-    %int2_4695 = torch.constant.int 2
-    %3861 = torch.aten.pow.Tensor_Scalar %3860, %int2_4695 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
-    %int-1_4696 = torch.constant.int -1
-    %3862 = torch.prim.ListConstruct %int-1_4696 : (!torch.int) -> !torch.list<int>
-    %true_4697 = torch.constant.bool true
-    %none_4698 = torch.constant.none
-    %3863 = torch.aten.mean.dim %3861, %3862, %true_4697, %none_4698 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
-    %float9.999990e-06_4699 = torch.constant.float 9.9999997473787516E-6
-    %int1_4700 = torch.constant.int 1
-    %3864 = torch.aten.add.Scalar %3863, %float9.999990e-06_4699, %int1_4700 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
-    %3865 = torch.aten.rsqrt %3864 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
-    %3866 = torch.aten.mul.Tensor %3860, %3865 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
-    %int5_4701 = torch.constant.int 5
-    %3867 = torch.prims.convert_element_type %3866, %int5_4701 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %3868 = torch.aten.mul.Tensor %188, %3867 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
-    %int5_4702 = torch.constant.int 5
-    %3869 = torch.prims.convert_element_type %3868, %int5_4702 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %int-2_4703 = torch.constant.int -2
-    %int-1_4704 = torch.constant.int -1
-    %3870 = torch.aten.transpose.int %189, %int-2_4703, %int-1_4704 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
-    %int4_4705 = torch.constant.int 4
-    %int4096_4706 = torch.constant.int 4096
-    %3871 = torch.prim.ListConstruct %int4_4705, %int4096_4706 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3872 = torch.aten.view %3869, %3871 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %3873 = torch.aten.mm %3872, %3870 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
-    %int4_4707 = torch.constant.int 4
-    %int1_4708 = torch.constant.int 1
-    %int4096_4709 = torch.constant.int 4096
-    %3874 = torch.prim.ListConstruct %int4_4707, %int1_4708, %int4096_4709 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3875 = torch.aten.view %3873, %3874 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
-    %int-2_4710 = torch.constant.int -2
-    %int-1_4711 = torch.constant.int -1
-    %3876 = torch.aten.transpose.int %190, %int-2_4710, %int-1_4711 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
-    %int4_4712 = torch.constant.int 4
-    %int4096_4713 = torch.constant.int 4096
-    %3877 = torch.prim.ListConstruct %int4_4712, %int4096_4713 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3878 = torch.aten.view %3869, %3877 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %3879 = torch.aten.mm %3878, %3876 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
-    %int4_4714 = torch.constant.int 4
-    %int1_4715 = torch.constant.int 1
-    %int1024_4716 = torch.constant.int 1024
-    %3880 = torch.prim.ListConstruct %int4_4714, %int1_4715, %int1024_4716 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3881 = torch.aten.view %3879, %3880 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
-    %int-2_4717 = torch.constant.int -2
-    %int-1_4718 = torch.constant.int -1
-    %3882 = torch.aten.transpose.int %191, %int-2_4717, %int-1_4718 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
-    %int4_4719 = torch.constant.int 4
-    %int4096_4720 = torch.constant.int 4096
-    %3883 = torch.prim.ListConstruct %int4_4719, %int4096_4720 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3884 = torch.aten.view %3869, %3883 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %3885 = torch.aten.mm %3884, %3882 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
-    %int4_4721 = torch.constant.int 4
-    %int1_4722 = torch.constant.int 1
-    %int1024_4723 = torch.constant.int 1024
-    %3886 = torch.prim.ListConstruct %int4_4721, %int1_4722, %int1024_4723 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3887 = torch.aten.view %3885, %3886 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
-    %int4_4724 = torch.constant.int 4
-    %int1_4725 = torch.constant.int 1
-    %int32_4726 = torch.constant.int 32
-    %int128_4727 = torch.constant.int 128
-    %3888 = torch.prim.ListConstruct %int4_4724, %int1_4725, %int32_4726, %int128_4727 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3889 = torch.aten.view %3875, %3888 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,32,128],f16>
-    %int4_4728 = torch.constant.int 4
-    %int1_4729 = torch.constant.int 1
-    %int8_4730 = torch.constant.int 8
-    %int128_4731 = torch.constant.int 128
-    %3890 = torch.prim.ListConstruct %int4_4728, %int1_4729, %int8_4730, %int128_4731 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3891 = torch.aten.view %3881, %3890 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
+    %int128_4691 = torch.constant.int 128
+    %4531 = torch.prim.ListConstruct %int4_4688, %457, %int8_4689, %int4_4690, %int128_4691 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %false_4692 = torch.constant.bool false
+    %4532 = torch.aten.expand %4530, %4531, %false_4692 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %4532, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int0_4693 = torch.constant.int 0
+    %4533 = torch.aten.clone %4532, %int0_4693 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+    torch.bind_symbolic_shape %4533, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+    %int4_4694 = torch.constant.int 4
+    %int32_4695 = torch.constant.int 32
+    %int128_4696 = torch.constant.int 128
+    %4534 = torch.prim.ListConstruct %int4_4694, %457, %int32_4695, %int128_4696 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4535 = torch.aten._unsafe_view %4533, %4534 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+    torch.bind_symbolic_shape %4535, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+    %int1_4697 = torch.constant.int 1
+    %int2_4698 = torch.constant.int 2
+    %4536 = torch.aten.transpose.int %4418, %int1_4697, %int2_4698 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+    %int1_4699 = torch.constant.int 1
+    %int2_4700 = torch.constant.int 2
+    %4537 = torch.aten.transpose.int %4529, %int1_4699, %int2_4700 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+    torch.bind_symbolic_shape %4537, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+    %int1_4701 = torch.constant.int 1
+    %int2_4702 = torch.constant.int 2
+    %4538 = torch.aten.transpose.int %4535, %int1_4701, %int2_4702 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+    torch.bind_symbolic_shape %4538, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+    %float0.000000e00_4703 = torch.constant.float 0.000000e+00
+    %false_4704 = torch.constant.bool false
+    %none_4705 = torch.constant.none
+    %4539:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4536, %4537, %4538, %float0.000000e00_4703, %false_4704, %470, %none_4705) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
+    %int1_4706 = torch.constant.int 1
+    %int2_4707 = torch.constant.int 2
+    %4540 = torch.aten.transpose.int %4539#0, %int1_4706, %int2_4707 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
+    %int4_4708 = torch.constant.int 4
+    %int1_4709 = torch.constant.int 1
+    %int4096_4710 = torch.constant.int 4096
+    %4541 = torch.prim.ListConstruct %int4_4708, %int1_4709, %int4096_4710 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4542 = torch.aten.view %4540, %4541 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+    %int-2_4711 = torch.constant.int -2
+    %int-1_4712 = torch.constant.int -1
+    %4543 = torch.aten.transpose.int %263, %int-2_4711, %int-1_4712 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+    %int5_4713 = torch.constant.int 5
+    %4544 = torch.prims.convert_element_type %4543, %int5_4713 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+    %int4_4714 = torch.constant.int 4
+    %int4096_4715 = torch.constant.int 4096
+    %4545 = torch.prim.ListConstruct %int4_4714, %int4096_4715 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4546 = torch.aten.view %4542, %4545 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %4547 = torch.aten.mm %4546, %4544 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
+    %int4_4716 = torch.constant.int 4
+    %int1_4717 = torch.constant.int 1
+    %int4096_4718 = torch.constant.int 4096
+    %4548 = torch.prim.ListConstruct %int4_4716, %int1_4717, %int4096_4718 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4549 = torch.aten.view %4547, %4548 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+    %int1_4719 = torch.constant.int 1
+    %4550 = torch.aten.add.Tensor %4371, %4549, %int1_4719 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %int6_4720 = torch.constant.int 6
+    %4551 = torch.prims.convert_element_type %4550, %int6_4720 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %int2_4721 = torch.constant.int 2
+    %4552 = torch.aten.pow.Tensor_Scalar %4551, %int2_4721 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %int-1_4722 = torch.constant.int -1
+    %4553 = torch.prim.ListConstruct %int-1_4722 : (!torch.int) -> !torch.list<int>
+    %true_4723 = torch.constant.bool true
+    %none_4724 = torch.constant.none
+    %4554 = torch.aten.mean.dim %4552, %4553, %true_4723, %none_4724 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
+    %float9.999990e-06_4725 = torch.constant.float 9.9999997473787516E-6
+    %int1_4726 = torch.constant.int 1
+    %4555 = torch.aten.add.Scalar %4554, %float9.999990e-06_4725, %int1_4726 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
+    %4556 = torch.aten.rsqrt %4555 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
+    %4557 = torch.aten.mul.Tensor %4551, %4556 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
+    %int5_4727 = torch.constant.int 5
+    %4558 = torch.prims.convert_element_type %4557, %int5_4727 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %4559 = torch.aten.mul.Tensor %264, %4558 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
+    %int5_4728 = torch.constant.int 5
+    %4560 = torch.prims.convert_element_type %4559, %int5_4728 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %int-2_4729 = torch.constant.int -2
+    %int-1_4730 = torch.constant.int -1
+    %4561 = torch.aten.transpose.int %265, %int-2_4729, %int-1_4730 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
+    %int5_4731 = torch.constant.int 5
+    %4562 = torch.prims.convert_element_type %4561, %int5_4731 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
     %int4_4732 = torch.constant.int 4
-    %int1_4733 = torch.constant.int 1
-    %int8_4734 = torch.constant.int 8
-    %int128_4735 = torch.constant.int 128
-    %3892 = torch.prim.ListConstruct %int4_4732, %int1_4733, %int8_4734, %int128_4735 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3893 = torch.aten.view %3887, %3892 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
-    %int6_4736 = torch.constant.int 6
-    %3894 = torch.prims.convert_element_type %3889, %int6_4736 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32>
-    %3895 = torch_c.to_builtin_tensor %3894 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32>
-    %3896 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32>
-    %3897 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%3895, %3896) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32>
-    %3898 = torch_c.from_builtin_tensor %3897 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32>
-    %int5_4737 = torch.constant.int 5
-    %3899 = torch.prims.convert_element_type %3898, %int5_4737 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
-    %int6_4738 = torch.constant.int 6
-    %3900 = torch.prims.convert_element_type %3891, %int6_4738 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32>
-    %3901 = torch_c.to_builtin_tensor %3900 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32>
-    %3902 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32>
-    %3903 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%3901, %3902) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32>
-    %3904 = torch_c.from_builtin_tensor %3903 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32>
+    %int4096_4733 = torch.constant.int 4096
+    %4563 = torch.prim.ListConstruct %int4_4732, %int4096_4733 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4564 = torch.aten.view %4560, %4563 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %4565 = torch.aten.mm %4564, %4562 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
+    %int4_4734 = torch.constant.int 4
+    %int1_4735 = torch.constant.int 1
+    %int14336_4736 = torch.constant.int 14336
+    %4566 = torch.prim.ListConstruct %int4_4734, %int1_4735, %int14336_4736 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4567 = torch.aten.view %4565, %4566 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
+    %4568 = torch.aten.silu %4567 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
+    %int-2_4737 = torch.constant.int -2
+    %int-1_4738 = torch.constant.int -1
+    %4569 = torch.aten.transpose.int %266, %int-2_4737, %int-1_4738 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
     %int5_4739 = torch.constant.int 5
-    %3905 = torch.prims.convert_element_type %3904, %int5_4739 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
-    %int32_4740 = torch.constant.int 32
-    %3906 = torch.aten.floor_divide.Scalar %arg2, %int32_4740 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
-    %int1_4741 = torch.constant.int 1
-    %3907 = torch.aten.unsqueeze %3906, %int1_4741 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4742 = torch.constant.int 1
-    %false_4743 = torch.constant.bool false
-    %3908 = torch.aten.gather %arg3, %int1_4742, %3907, %false_4743 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64>
-    %int32_4744 = torch.constant.int 32
-    %3909 = torch.aten.remainder.Scalar %arg2, %int32_4744 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
-    %int1_4745 = torch.constant.int 1
-    %3910 = torch.aten.unsqueeze %3909, %int1_4745 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %none_4746 = torch.constant.none
-    %3911 = torch.aten.clone %192, %none_4746 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
-    %int0_4747 = torch.constant.int 0
-    %3912 = torch.aten.unsqueeze %3911, %int0_4747 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
+    %4570 = torch.prims.convert_element_type %4569, %int5_4739 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16>
+    %int4_4740 = torch.constant.int 4
+    %int4096_4741 = torch.constant.int 4096
+    %4571 = torch.prim.ListConstruct %int4_4740, %int4096_4741 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4572 = torch.aten.view %4560, %4571 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %4573 = torch.aten.mm %4572, %4570 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
+    %int4_4742 = torch.constant.int 4
+    %int1_4743 = torch.constant.int 1
+    %int14336_4744 = torch.constant.int 14336
+    %4574 = torch.prim.ListConstruct %int4_4742, %int1_4743, %int14336_4744 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4575 = torch.aten.view %4573, %4574 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
+    %4576 = torch.aten.mul.Tensor %4568, %4575 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
+    %int-2_4745 = torch.constant.int -2
+    %int-1_4746 = torch.constant.int -1
+    %4577 = torch.aten.transpose.int %267, %int-2_4745, %int-1_4746 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
+    %int5_4747 = torch.constant.int 5
+    %4578 = torch.prims.convert_element_type %4577, %int5_4747 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16>
     %int4_4748 = torch.constant.int 4
-    %int1_4749 = torch.constant.int 1
-    %3913 = torch.prim.ListConstruct %int4_4748, %int1_4749 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int1_4750 = torch.constant.int 1
+    %int14336_4749 = torch.constant.int 14336
+    %4579 = torch.prim.ListConstruct %int4_4748, %int14336_4749 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4580 = torch.aten.view %4576, %4579 : !torch.vtensor<[4,1,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,14336],f16>
+    %4581 = torch.aten.mm %4580, %4578 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16>
+    %int4_4750 = torch.constant.int 4
     %int1_4751 = torch.constant.int 1
-    %3914 = torch.prim.ListConstruct %int1_4750, %int1_4751 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int4_4752 = torch.constant.int 4
-    %int0_4753 = torch.constant.int 0
-    %cpu_4754 = torch.constant.device "cpu"
-    %false_4755 = torch.constant.bool false
-    %3915 = torch.aten.empty_strided %3913, %3914, %int4_4752, %int0_4753, %cpu_4754, %false_4755 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64>
-    %int17 = torch.constant.int 17
-    %3916 = torch.aten.fill.Scalar %3915, %int17 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int4_4756 = torch.constant.int 4
-    %int1_4757 = torch.constant.int 1
-    %3917 = torch.prim.ListConstruct %int4_4756, %int1_4757 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3918 = torch.aten.repeat %3912, %3917 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64>
-    %int32_4758 = torch.constant.int 32
-    %3919 = torch.aten.mul.Scalar %3908, %int32_4758 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4759 = torch.constant.int 1
-    %3920 = torch.aten.add.Tensor %3919, %3916, %int1_4759 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int2_4760 = torch.constant.int 2
-    %3921 = torch.aten.mul.Scalar %3920, %int2_4760 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4761 = torch.constant.int 1
-    %3922 = torch.aten.add.Tensor %3921, %3918, %int1_4761 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int32_4762 = torch.constant.int 32
-    %3923 = torch.aten.mul.Scalar %3922, %int32_4762 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4763 = torch.constant.int 1
-    %3924 = torch.aten.add.Tensor %3923, %3910, %int1_4763 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int32_4764 = torch.constant.int 32
-    %int2_4765 = torch.constant.int 2
-    %int32_4766 = torch.constant.int 32
-    %int8_4767 = torch.constant.int 8
-    %int128_4768 = torch.constant.int 128
-    %3925 = torch.prim.ListConstruct %437, %int32_4764, %int2_4765, %int32_4766, %int8_4767, %int128_4768 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3926 = torch.aten.view %3762, %3925 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3926, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
-    %int32_4769 = torch.constant.int 32
-    %3927 = torch.aten.mul.int %437, %int32_4769 : !torch.int, !torch.int -> !torch.int
-    %int2_4770 = torch.constant.int 2
-    %3928 = torch.aten.mul.int %3927, %int2_4770 : !torch.int, !torch.int -> !torch.int
-    %int32_4771 = torch.constant.int 32
-    %3929 = torch.aten.mul.int %3928, %int32_4771 : !torch.int, !torch.int -> !torch.int
-    %int8_4772 = torch.constant.int 8
-    %int128_4773 = torch.constant.int 128
-    %3930 = torch.prim.ListConstruct %3929, %int8_4772, %int128_4773 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3931 = torch.aten.view %3926, %3930 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,128],f16>
-    torch.bind_symbolic_shape %3931, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
-    %3932 = torch.prim.ListConstruct %3924 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>>
-    %false_4774 = torch.constant.bool false
-    %3933 = torch.aten.index_put %3931, %3932, %3905, %false_4774 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16>
-    torch.bind_symbolic_shape %3933, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
-    %int32_4775 = torch.constant.int 32
-    %int2_4776 = torch.constant.int 2
-    %int32_4777 = torch.constant.int 32
-    %int8_4778 = torch.constant.int 8
-    %int128_4779 = torch.constant.int 128
-    %3934 = torch.prim.ListConstruct %437, %int32_4775, %int2_4776, %int32_4777, %int8_4778, %int128_4779 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3935 = torch.aten.view %3933, %3934 : !torch.vtensor<[?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3935, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
-    %int2097152_4780 = torch.constant.int 2097152
-    %3936 = torch.prim.ListConstruct %437, %int2097152_4780 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3937 = torch.aten.view %3935, %3936 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
-    torch.bind_symbolic_shape %3937, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
-    %int32_4781 = torch.constant.int 32
-    %int2_4782 = torch.constant.int 2
-    %int32_4783 = torch.constant.int 32
-    %int8_4784 = torch.constant.int 8
-    %int128_4785 = torch.constant.int 128
-    %3938 = torch.prim.ListConstruct %437, %int32_4781, %int2_4782, %int32_4783, %int8_4784, %int128_4785 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3939 = torch.aten.view %3937, %3938 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3939, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
-    %int8_4786 = torch.constant.int 8
-    %int128_4787 = torch.constant.int 128
-    %3940 = torch.prim.ListConstruct %3929, %int8_4786, %int128_4787 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3941 = torch.aten.view %3939, %3940 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,128],f16>
-    torch.bind_symbolic_shape %3941, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
-    %int32_4788 = torch.constant.int 32
-    %3942 = torch.aten.floor_divide.Scalar %arg2, %int32_4788 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
-    %int1_4789 = torch.constant.int 1
-    %3943 = torch.aten.unsqueeze %3942, %int1_4789 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4790 = torch.constant.int 1
-    %false_4791 = torch.constant.bool false
-    %3944 = torch.aten.gather %arg3, %int1_4790, %3943, %false_4791 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64>
-    %int32_4792 = torch.constant.int 32
-    %3945 = torch.aten.remainder.Scalar %arg2, %int32_4792 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
-    %int1_4793 = torch.constant.int 1
-    %3946 = torch.aten.unsqueeze %3945, %int1_4793 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %none_4794 = torch.constant.none
-    %3947 = torch.aten.clone %193, %none_4794 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
-    %int0_4795 = torch.constant.int 0
-    %3948 = torch.aten.unsqueeze %3947, %int0_4795 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
-    %int4_4796 = torch.constant.int 4
-    %int1_4797 = torch.constant.int 1
-    %3949 = torch.prim.ListConstruct %int4_4796, %int1_4797 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int1_4798 = torch.constant.int 1
+    %int4096_4752 = torch.constant.int 4096
+    %4582 = torch.prim.ListConstruct %int4_4750, %int1_4751, %int4096_4752 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4583 = torch.aten.view %4581, %4582 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+    %int1_4753 = torch.constant.int 1
+    %4584 = torch.aten.add.Tensor %4550, %4583, %int1_4753 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %int6_4754 = torch.constant.int 6
+    %4585 = torch.prims.convert_element_type %4584, %int6_4754 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %int2_4755 = torch.constant.int 2
+    %4586 = torch.aten.pow.Tensor_Scalar %4585, %int2_4755 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+    %int-1_4756 = torch.constant.int -1
+    %4587 = torch.prim.ListConstruct %int-1_4756 : (!torch.int) -> !torch.list<int>
+    %true_4757 = torch.constant.bool true
+    %none_4758 = torch.constant.none
+    %4588 = torch.aten.mean.dim %4586, %4587, %true_4757, %none_4758 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
+    %float9.999990e-06_4759 = torch.constant.float 9.9999997473787516E-6
+    %int1_4760 = torch.constant.int 1
+    %4589 = torch.aten.add.Scalar %4588, %float9.999990e-06_4759, %int1_4760 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
+    %4590 = torch.aten.rsqrt %4589 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
+    %4591 = torch.aten.mul.Tensor %4585, %4590 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
+    %int5_4761 = torch.constant.int 5
+    %4592 = torch.prims.convert_element_type %4591, %int5_4761 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %4593 = torch.aten.mul.Tensor %268, %4592 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
+    %int5_4762 = torch.constant.int 5
+    %4594 = torch.prims.convert_element_type %4593, %int5_4762 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+    %int-2_4763 = torch.constant.int -2
+    %int-1_4764 = torch.constant.int -1
+    %4595 = torch.aten.transpose.int %269, %int-2_4763, %int-1_4764 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+    %int5_4765 = torch.constant.int 5
+    %4596 = torch.prims.convert_element_type %4595, %int5_4765 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+    %int4_4766 = torch.constant.int 4
+    %int4096_4767 = torch.constant.int 4096
+    %4597 = torch.prim.ListConstruct %int4_4766, %int4096_4767 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4598 = torch.aten.view %4594, %4597 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %4599 = torch.aten.mm %4598, %4596 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
+    %int4_4768 = torch.constant.int 4
+    %int1_4769 = torch.constant.int 1
+    %int4096_4770 = torch.constant.int 4096
+    %4600 = torch.prim.ListConstruct %int4_4768, %int1_4769, %int4096_4770 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4601 = torch.aten.view %4599, %4600 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+    %int-2_4771 = torch.constant.int -2
+    %int-1_4772 = torch.constant.int -1
+    %4602 = torch.aten.transpose.int %270, %int-2_4771, %int-1_4772 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+    %int5_4773 = torch.constant.int 5
+    %4603 = torch.prims.convert_element_type %4602, %int5_4773 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+    %int4_4774 = torch.constant.int 4
+    %int4096_4775 = torch.constant.int 4096
+    %4604 = torch.prim.ListConstruct %int4_4774, %int4096_4775 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4605 = torch.aten.view %4594, %4604 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %4606 = torch.aten.mm %4605, %4603 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
+    %int4_4776 = torch.constant.int 4
+    %int1_4777 = torch.constant.int 1
+    %int1024_4778 = torch.constant.int 1024
+    %4607 = torch.prim.ListConstruct %int4_4776, %int1_4777, %int1024_4778 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4608 = torch.aten.view %4606, %4607 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
+    %int-2_4779 = torch.constant.int -2
+    %int-1_4780 = torch.constant.int -1
+    %4609 = torch.aten.transpose.int %271, %int-2_4779, %int-1_4780 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+    %int5_4781 = torch.constant.int 5
+    %4610 = torch.prims.convert_element_type %4609, %int5_4781 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+    %int4_4782 = torch.constant.int 4
+    %int4096_4783 = torch.constant.int 4096
+    %4611 = torch.prim.ListConstruct %int4_4782, %int4096_4783 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4612 = torch.aten.view %4594, %4611 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %4613 = torch.aten.mm %4612, %4610 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
+    %int4_4784 = torch.constant.int 4
+    %int1_4785 = torch.constant.int 1
+    %int1024_4786 = torch.constant.int 1024
+    %4614 = torch.prim.ListConstruct %int4_4784, %int1_4785, %int1024_4786 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4615 = torch.aten.view %4613, %4614 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
+    %int4_4787 = torch.constant.int 4
+    %int1_4788 = torch.constant.int 1
+    %int32_4789 = torch.constant.int 32
+    %int128_4790 = torch.constant.int 128
+    %4616 = torch.prim.ListConstruct %int4_4787, %int1_4788, %int32_4789, %int128_4790 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4617 = torch.aten.view %4601, %4616 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,32,128],f16>
+    %int4_4791 = torch.constant.int 4
+    %int1_4792 = torch.constant.int 1
+    %int8_4793 = torch.constant.int 8
+    %int128_4794 = torch.constant.int 128
+    %4618 = torch.prim.ListConstruct %int4_4791, %int1_4792, %int8_4793, %int128_4794 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4619 = torch.aten.view %4608, %4618 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
+    %int4_4795 = torch.constant.int 4
+    %int1_4796 = torch.constant.int 1
+    %int8_4797 = torch.constant.int 8
+    %int128_4798 = torch.constant.int 128
+    %4620 = torch.prim.ListConstruct %int4_4795, %int1_4796, %int8_4797, %int128_4798 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %4621 = torch.aten.view %4615, %4620 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
     %int1_4799 = torch.constant.int 1
-    %3950 = torch.prim.ListConstruct %int1_4798, %int1_4799 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int4_4800 = torch.constant.int 4
-    %int0_4801 = torch.constant.int 0
-    %cpu_4802 = torch.constant.device "cpu"
-    %false_4803 = torch.constant.bool false
-    %3951 = torch.aten.empty_strided %3949, %3950, %int4_4800, %int0_4801, %cpu_4802, %false_4803 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64>
-    %int17_4804 = torch.constant.int 17
-    %3952 = torch.aten.fill.Scalar %3951, %int17_4804 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int4_4805 = torch.constant.int 4
-    %int1_4806 = torch.constant.int 1
-    %3953 = torch.prim.ListConstruct %int4_4805, %int1_4806 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3954 = torch.aten.repeat %3948, %3953 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64>
-    %int32_4807 = torch.constant.int 32
-    %3955 = torch.aten.mul.Scalar %3944, %int32_4807 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
+    %int2_4800 = torch.constant.int 2
+    %4622 = torch.aten.transpose.int %4617, %int1_4799, %int2_4800 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+    %4623 = torch.aten.mul.Tensor %4622, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16>
+    %int3_4801 = torch.constant.int 3
+    %int0_4802 = torch.constant.int 0
+    %int64_4803 = torch.constant.int 64
+    %int1_4804 = torch.constant.int 1
+    %4624 = torch.aten.slice.Tensor %4622, %int3_4801, %int0_4802, %int64_4803, %int1_4804 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16>
+    %int3_4805 = torch.constant.int 3
+    %int64_4806 = torch.constant.int 64
+    %int9223372036854775807_4807 = torch.constant.int 9223372036854775807
     %int1_4808 = torch.constant.int 1
-    %3956 = torch.aten.add.Tensor %3955, %3952, %int1_4808 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int2_4809 = torch.constant.int 2
-    %3957 = torch.aten.mul.Scalar %3956, %int2_4809 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
+    %4625 = torch.aten.slice.Tensor %4622, %int3_4805, %int64_4806, %int9223372036854775807_4807, %int1_4808 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16>
+    %4626 = torch.aten.neg %4625 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16>
+    %4627 = torch.prim.ListConstruct %4626, %4624 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list<vtensor>
+    %int-1_4809 = torch.constant.int -1
+    %4628 = torch.aten.cat %4627, %int-1_4809 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+    %4629 = torch.aten.mul.Tensor %4628, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16>
     %int1_4810 = torch.constant.int 1
-    %3958 = torch.aten.add.Tensor %3957, %3954, %int1_4810 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int32_4811 = torch.constant.int 32
-    %3959 = torch.aten.mul.Scalar %3958, %int32_4811 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %int1_4812 = torch.constant.int 1
-    %3960 = torch.aten.add.Tensor %3959, %3946, %int1_4812 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
-    %3961 = torch.prim.ListConstruct %3960 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>>
-    %false_4813 = torch.constant.bool false
-    %3962 = torch.aten.index_put %3941, %3961, %3893, %false_4813 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16>
-    torch.bind_symbolic_shape %3962, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
-    %int32_4814 = torch.constant.int 32
-    %int2_4815 = torch.constant.int 2
-    %int32_4816 = torch.constant.int 32
-    %int8_4817 = torch.constant.int 8
-    %int128_4818 = torch.constant.int 128
-    %3963 = torch.prim.ListConstruct %437, %int32_4814, %int2_4815, %int32_4816, %int8_4817, %int128_4818 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %3964 = torch.aten.view %3962, %3963 : !torch.vtensor<[?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
-    torch.bind_symbolic_shape %3964, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
-    %int2097152_4819 = torch.constant.int 2097152
-    %3965 = torch.prim.ListConstruct %437, %int2097152_4819 : (!torch.int, !torch.int) -> !torch.list<int>
-    %3966 = torch.aten.view %3964, %3965 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
-    torch.bind_symbolic_shape %3966, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
-    %int4_4820 = torch.constant.int 4
-    %3967 = torch.prim.ListConstruct %int4_4820, %358 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int1_4821 = torch.constant.int 1
-    %3968 = torch.prim.ListConstruct %358, %int1_4821 : (!torch.int, !torch.int) -> !torch.list<int>
-    %int4_4822 = torch.constant.int 4
-    %int0_4823 = torch.constant.int 0
-    %cpu_4824 = torch.constant.device "cpu"
-    %false_4825 = torch.constant.bool false
-    %3969 = torch.aten.empty_strided %3967, %3968, %int4_4822, %int0_4823, %cpu_4824, %false_4825 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64>
-    torch.bind_symbolic_shape %3969, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
-    %int17_4826 = torch.constant.int 17
-    %3970 = torch.aten.fill.Scalar %3969, %int17_4826 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
-    torch.bind_symbolic_shape %3970, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
+    %4630 = torch.aten.add.Tensor %4623, %4629, %int1_4810 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+    %int1_4811 = torch.constant.int 1
+    %int2_4812 = torch.constant.int 2
+    %4631 = torch.aten.transpose.int %4630, %int1_4811, %int2_4812 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
+    %int1_4813 = torch.constant.int 1
+    %int2_4814 = torch.constant.int 2
+    %4632 = torch.aten.transpose.int %4619, %int1_4813, %int2_4814 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16>
+    %4633 = torch.aten.mul.Tensor %4632, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16>
+    %int3_4815 = torch.constant.int 3
+    %int0_4816 = torch.constant.int 0
+    %int64_4817 = torch.constant.int 64
+    %int1_4818 = torch.constant.int 1
+    %4634 = torch.aten.slice.Tensor %4632, %int3_4815, %int0_4816, %int64_4817, %int1_4818 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16>
+    %int3_4819 = torch.constant.int 3
+    %int64_4820 = torch.constant.int 64
+    %int9223372036854775807_4821 = torch.constant.int 9223372036854775807
+    %int1_4822 = torch.constant.int 1
+    %4635 = torch.aten.slice.Tensor %4632, %int3_4819, %int64_4820, %int9223372036854775807_4821, %int1_4822 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16>
+    %4636 = torch.aten.neg %4635 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16>
+    %4637 = torch.prim.ListConstruct %4636, %4634 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list<vtensor>
+    %int-1_4823 = torch.constant.int -1
+    %4638 = torch.aten.cat %4637, %int-1_4823 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,8,1,128],f16>
+    %4639 = torch.aten.mul.Tensor %4638, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16>
+    %int1_4824 = torch.constant.int 1
+    %4640 = torch.aten.add.Tensor %4633, %4639, %int1_4824 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16>
+    %int1_4825 = torch.constant.int 1
+    %int2_4826 = torch.constant.int 2
+    %4641 = torch.aten.transpose.int %4640, %int1_4825, %int2_4826 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
     %int32_4827 = torch.constant.int 32
-    %3971 = torch.aten.mul.Scalar %arg3, %int32_4827 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
-
torch.bind_symbolic_shape %3971, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %4642 = torch.aten.floor_divide.Scalar %arg2, %int32_4827 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> %int1_4828 = torch.constant.int 1 - %3972 = torch.aten.add.Tensor %3971, %3970, %int1_4828 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %3972, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_4829 = torch.constant.int 4 - %3973 = torch.aten.mul.int %int4_4829, %358 : !torch.int, !torch.int -> !torch.int - %3974 = torch.prim.ListConstruct %3973 : (!torch.int) -> !torch.list - %3975 = torch.aten.view %3972, %3974 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %3975, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_4830 = torch.constant.int 32 - %int2_4831 = torch.constant.int 2 - %int32_4832 = torch.constant.int 32 - %int8_4833 = torch.constant.int 8 - %int128_4834 = torch.constant.int 128 - %3976 = torch.prim.ListConstruct %437, %int32_4830, %int2_4831, %int32_4832, %int8_4833, %int128_4834 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3977 = torch.aten.view %3966, %3976 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %3977, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_4835 = torch.constant.int 32 - %3978 = torch.aten.mul.int %437, %int32_4835 : !torch.int, !torch.int -> !torch.int - %int2_4836 = torch.constant.int 2 - %int32_4837 = torch.constant.int 32 + %4643 = torch.aten.unsqueeze %4642, %int1_4828 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_4829 = torch.constant.int 1 + %false_4830 = torch.constant.bool false + %4644 = torch.aten.gather %arg3, %int1_4829, %4643, %false_4830 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_4831 = torch.constant.int 4 + %int1_4832 = torch.constant.int 1 + %int1_4833 = torch.constant.int 1 + %4645 = torch.prim.ListConstruct %int4_4831, %int1_4832, %int1_4833 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4646 = torch.aten.view %4644, %4645 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_4834 = torch.constant.int 32 + %4647 = torch.aten.remainder.Scalar %arg2, %int32_4834 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_4835 = torch.constant.int 4 + %int1_4836 = torch.constant.int 1 + %int1_4837 = torch.constant.int 1 + %4648 = torch.prim.ListConstruct %int4_4835, %int1_4836, %int1_4837 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4649 = torch.aten.view %4647, %4648 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> %int8_4838 = torch.constant.int 8 - %int128_4839 = torch.constant.int 128 - %3979 = torch.prim.ListConstruct %3978, %int2_4836, %int32_4837, %int8_4838, %int128_4839 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3980 = torch.aten.view %3977, %3979 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %3980, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_4840 = torch.constant.int 0 - %3981 = torch.aten.index_select 
%3980, %int0_4840, %3975 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %3981, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_4841 = torch.constant.int 4 - %int2_4842 = torch.constant.int 2 - %int32_4843 = torch.constant.int 32 - %int8_4844 = torch.constant.int 8 - %int128_4845 = torch.constant.int 128 - %3982 = torch.prim.ListConstruct %int4_4841, %358, %int2_4842, %int32_4843, %int8_4844, %int128_4845 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3983 = torch.aten.view %3981, %3982 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3983, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_4846 = torch.constant.int 0 - %int0_4847 = torch.constant.int 0 - %int9223372036854775807_4848 = torch.constant.int 9223372036854775807 + %none_4839 = torch.constant.none + %none_4840 = torch.constant.none + %cpu_4841 = torch.constant.device "cpu" + %false_4842 = torch.constant.bool false + %4650 = torch.aten.arange %int8_4838, %none_4839, %none_4840, %cpu_4841, %false_4842 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_4843 = torch.constant.int 1 + %int1_4844 = torch.constant.int 1 + %int8_4845 = torch.constant.int 8 + %4651 = torch.prim.ListConstruct %int1_4843, %int1_4844, %int8_4845 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4652 = torch.aten.view %4650, %4651 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_4846 = torch.constant.none + %4653 = torch.aten.clone %272, %none_4846 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4654 = torch.aten.detach %4653 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4655 = torch.aten.detach %4654 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4656 = torch.aten.detach %4655 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_4847 = torch.constant.int 1 + %int1_4848 = torch.constant.int 1 %int1_4849 = torch.constant.int 1 - %3984 = torch.aten.slice.Tensor %3983, %int0_4846, %int0_4847, %int9223372036854775807_4848, %int1_4849 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3984, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_4850 = torch.constant.int 1 - %int0_4851 = torch.constant.int 0 - %int9223372036854775807_4852 = torch.constant.int 9223372036854775807 + %4657 = torch.prim.ListConstruct %int1_4847, %int1_4848, %int1_4849 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4658 = torch.aten.view %4656, %4657 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_4850 = torch.constant.int 32 + %4659 = torch.aten.mul.Scalar %4646, %int32_4850 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int19 = torch.constant.int 19 + %int1_4851 = torch.constant.int 1 + %4660 = torch.aten.add.Scalar %4659, %int19, %int1_4851 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_4852 = torch.constant.int 2 + %4661 = torch.aten.mul.Scalar %4660, %int2_4852 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_4853 = torch.constant.int 1 - %3985 = 
torch.aten.slice.Tensor %3984, %int1_4850, %int0_4851, %int9223372036854775807_4852, %int1_4853 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3985, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_4854 = torch.constant.int 2 - %int0_4855 = torch.constant.int 0 - %3986 = torch.aten.select.int %3985, %int2_4854, %int0_4855 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3986, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %4662 = torch.aten.add.Tensor %4661, %4658, %int1_4853 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_4854 = torch.constant.int 8 + %4663 = torch.aten.mul.Scalar %4662, %int8_4854 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_4855 = torch.constant.int 1 + %4664 = torch.aten.add.Tensor %4663, %4652, %int1_4855 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int32_4856 = torch.constant.int 32 - %3987 = torch.aten.mul.int %358, %int32_4856 : !torch.int, !torch.int -> !torch.int - %int2_4857 = torch.constant.int 2 - %int0_4858 = torch.constant.int 0 - %int1_4859 = torch.constant.int 1 - %3988 = torch.aten.slice.Tensor %3986, %int2_4857, %int0_4858, %3987, %int1_4859 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3988, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_4860 = torch.constant.int 0 - %3989 = torch.aten.clone %3988, %int0_4860 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3989, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_4861 = torch.constant.int 1 - %3990 = torch.aten.size.int %3985, %int1_4861 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int + %4665 = torch.aten.mul.Scalar %4664, %int32_4856 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_4857 = torch.constant.int 1 + %4666 = torch.aten.add.Tensor %4665, %4649, %int1_4857 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_4858 = torch.constant.int 5 + %4667 = torch.prims.convert_element_type %4641, %int5_4858 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_4859 = torch.constant.int 32 + %int2_4860 = torch.constant.int 2 + %int8_4861 = torch.constant.int 8 %int32_4862 = torch.constant.int 32 - %3991 = torch.aten.mul.int %3990, %int32_4862 : !torch.int, !torch.int -> !torch.int - %int4_4863 = torch.constant.int 4 - %int8_4864 = torch.constant.int 8 - %int128_4865 = torch.constant.int 128 - %3992 = torch.prim.ListConstruct %int4_4863, %3991, %int8_4864, %int128_4865 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %3993 = torch.aten._unsafe_view %3989, %3992 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3993, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_4866 = torch.constant.int 0 - %int0_4867 = torch.constant.int 0 - 
%int9223372036854775807_4868 = torch.constant.int 9223372036854775807 - %int1_4869 = torch.constant.int 1 - %3994 = torch.aten.slice.Tensor %3993, %int0_4866, %int0_4867, %int9223372036854775807_4868, %int1_4869 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %3994, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_4870 = torch.constant.int 0 - %int0_4871 = torch.constant.int 0 - %int9223372036854775807_4872 = torch.constant.int 9223372036854775807 - %int1_4873 = torch.constant.int 1 - %3995 = torch.aten.slice.Tensor %3983, %int0_4870, %int0_4871, %int9223372036854775807_4872, %int1_4873 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3995, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_4874 = torch.constant.int 1 - %int0_4875 = torch.constant.int 0 - %int9223372036854775807_4876 = torch.constant.int 9223372036854775807 - %int1_4877 = torch.constant.int 1 - %3996 = torch.aten.slice.Tensor %3995, %int1_4874, %int0_4875, %int9223372036854775807_4876, %int1_4877 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %3996, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_4878 = torch.constant.int 2 + %int128_4863 = torch.constant.int 128 + %4668 = torch.prim.ListConstruct %456, %int32_4859, %int2_4860, %int8_4861, %int32_4862, %int128_4863 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4669 = torch.aten.view %4489, %4668 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4669, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_4864 = torch.constant.int 128 + %4670 = torch.prim.ListConstruct %596, %int128_4864 : (!torch.int, !torch.int) -> !torch.list + %4671 = torch.aten.view %4669, %4670 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4671, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %4672 = torch.prim.ListConstruct %4666 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_4865 = torch.constant.bool false + %4673 = torch.aten.index_put %4671, %4672, %4667, %false_4865 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4673, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_4866 = torch.constant.int 32 + %int2_4867 = torch.constant.int 2 + %int8_4868 = torch.constant.int 8 + %int32_4869 = torch.constant.int 32 + %int128_4870 = torch.constant.int 128 + %4674 = torch.prim.ListConstruct %456, %int32_4866, %int2_4867, %int8_4868, %int32_4869, %int128_4870 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4675 = torch.aten.view %4673, %4674 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4675, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_4871 = torch.constant.int 2097152 + %4676 = 
torch.prim.ListConstruct %456, %int2097152_4871 : (!torch.int, !torch.int) -> !torch.list + %4677 = torch.aten.view %4675, %4676 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %4677, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_4872 = torch.constant.int 32 + %int2_4873 = torch.constant.int 2 + %int8_4874 = torch.constant.int 8 + %int32_4875 = torch.constant.int 32 + %int128_4876 = torch.constant.int 128 + %4678 = torch.prim.ListConstruct %456, %int32_4872, %int2_4873, %int8_4874, %int32_4875, %int128_4876 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4679 = torch.aten.view %4677, %4678 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4679, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_4877 = torch.constant.int 128 + %4680 = torch.prim.ListConstruct %596, %int128_4877 : (!torch.int, !torch.int) -> !torch.list + %4681 = torch.aten.view %4679, %4680 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4681, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_4878 = torch.constant.none + %4682 = torch.aten.clone %273, %none_4878 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4683 = torch.aten.detach %4682 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4684 = torch.aten.detach %4683 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4685 = torch.aten.detach %4684 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %int1_4879 = torch.constant.int 1 - %3997 = torch.aten.select.int %3996, %int2_4878, %int1_4879 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3997, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_4880 = torch.constant.int 2 - %int0_4881 = torch.constant.int 0 - %int1_4882 = torch.constant.int 1 - %3998 = torch.aten.slice.Tensor %3997, %int2_4880, %int0_4881, %3987, %int1_4882 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3998, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_4883 = torch.constant.int 0 - %3999 = torch.aten.clone %3998, %int0_4883 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %3999, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int1_4880 = torch.constant.int 1 + %int1_4881 = torch.constant.int 1 + %4686 = torch.prim.ListConstruct %int1_4879, %int1_4880, %int1_4881 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4687 = torch.aten.view %4685, %4686 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_4882 = torch.constant.int 32 + %4688 = torch.aten.mul.Scalar %4646, %int32_4882 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int19_4883 = torch.constant.int 19 %int1_4884 = torch.constant.int 1 - %4000 = torch.aten.size.int %3996, %int1_4884 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_4885 = torch.constant.int 32 - %4001 = torch.aten.mul.int %4000, %int32_4885 : !torch.int, 
!torch.int -> !torch.int - %int4_4886 = torch.constant.int 4 + %4689 = torch.aten.add.Scalar %4688, %int19_4883, %int1_4884 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_4885 = torch.constant.int 2 + %4690 = torch.aten.mul.Scalar %4689, %int2_4885 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_4886 = torch.constant.int 1 + %4691 = torch.aten.add.Tensor %4690, %4687, %int1_4886 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int8_4887 = torch.constant.int 8 - %int128_4888 = torch.constant.int 128 - %4002 = torch.prim.ListConstruct %int4_4886, %4001, %int8_4887, %int128_4888 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4003 = torch.aten._unsafe_view %3999, %4002 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4003, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_4889 = torch.constant.int 0 - %int0_4890 = torch.constant.int 0 - %int9223372036854775807_4891 = torch.constant.int 9223372036854775807 - %int1_4892 = torch.constant.int 1 - %4004 = torch.aten.slice.Tensor %4003, %int0_4889, %int0_4890, %int9223372036854775807_4891, %int1_4892 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4004, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_4893 = torch.constant.int -2 - %4005 = torch.aten.unsqueeze %3994, %int-2_4893 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4005, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_4894 = torch.constant.int 1 - %4006 = torch.aten.size.int %3993, %int1_4894 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_4895 = torch.constant.int 4 - %int8_4896 = torch.constant.int 8 - %int4_4897 = torch.constant.int 4 - %int128_4898 = torch.constant.int 128 - %4007 = torch.prim.ListConstruct %int4_4895, %4006, %int8_4896, %int4_4897, %int128_4898 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_4899 = torch.constant.bool false - %4008 = torch.aten.expand %4005, %4007, %false_4899 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4008, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_4900 = torch.constant.int 0 - %4009 = torch.aten.clone %4008, %int0_4900 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4009, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_4901 = torch.constant.int 4 + %4692 = torch.aten.mul.Scalar %4691, %int8_4887 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_4888 = torch.constant.int 1 + %4693 = torch.aten.add.Tensor %4692, %4652, %int1_4888 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_4889 = torch.constant.int 32 + %4694 = torch.aten.mul.Scalar %4693, %int32_4889 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_4890 = torch.constant.int 1 + %4695 = torch.aten.add.Tensor %4694, %4649, %int1_4890 : 
!torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_4891 = torch.constant.int 5 + %4696 = torch.prims.convert_element_type %4621, %int5_4891 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %4697 = torch.prim.ListConstruct %4695 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_4892 = torch.constant.bool false + %4698 = torch.aten.index_put %4681, %4697, %4696, %false_4892 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4698, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_4893 = torch.constant.int 32 + %int2_4894 = torch.constant.int 2 + %int8_4895 = torch.constant.int 8 + %int32_4896 = torch.constant.int 32 + %int128_4897 = torch.constant.int 128 + %4699 = torch.prim.ListConstruct %456, %int32_4893, %int2_4894, %int8_4895, %int32_4896, %int128_4897 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4700 = torch.aten.view %4698, %4699 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4700, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_4898 = torch.constant.int 2097152 + %4701 = torch.prim.ListConstruct %456, %int2097152_4898 : (!torch.int, !torch.int) -> !torch.list + %4702 = torch.aten.view %4700, %4701 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %4702, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_4899 = torch.constant.none + %4703 = torch.aten.clone %274, %none_4899 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4704 = torch.aten.detach %4703 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4705 = torch.aten.detach %4704 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4706 = torch.aten.detach %4705 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_4900 = torch.constant.none + %4707 = torch.aten.clone %275, %none_4900 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4708 = torch.aten.detach %4707 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4709 = torch.aten.detach %4708 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4710 = torch.aten.detach %4709 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_4901 = torch.constant.none + %4711 = torch.aten.clone %276, %none_4901 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4712 = torch.aten.detach %4711 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4713 = torch.aten.detach %4712 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4714 = torch.aten.detach %4713 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %int32_4902 = torch.constant.int 32 - %int128_4903 = torch.constant.int 128 - %4010 = torch.prim.ListConstruct %int4_4901, %4006, %int32_4902, %int128_4903 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4011 = torch.aten._unsafe_view %4009, %4010 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4011, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_4904 = torch.constant.int -2 - %4012 = torch.aten.unsqueeze %4004, %int-2_4904 : !torch.vtensor<[4,?,8,128],f16>, 
!torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4012, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_4905 = torch.constant.int 1 - %4013 = torch.aten.size.int %4003, %int1_4905 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_4906 = torch.constant.int 4 - %int8_4907 = torch.constant.int 8 - %int4_4908 = torch.constant.int 4 - %int128_4909 = torch.constant.int 128 - %4014 = torch.prim.ListConstruct %int4_4906, %4013, %int8_4907, %int4_4908, %int128_4909 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_4910 = torch.constant.bool false - %4015 = torch.aten.expand %4012, %4014, %false_4910 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4015, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_4911 = torch.constant.int 0 - %4016 = torch.aten.clone %4015, %int0_4911 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4016, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_4912 = torch.constant.int 4 - %int32_4913 = torch.constant.int 32 - %int128_4914 = torch.constant.int 128 - %4017 = torch.prim.ListConstruct %int4_4912, %4013, %int32_4913, %int128_4914 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4018 = torch.aten._unsafe_view %4016, %4017 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4018, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_4915 = torch.constant.int 1 - %int2_4916 = torch.constant.int 2 - %4019 = torch.aten.transpose.int %3899, %int1_4915, %int2_4916 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_4917 = torch.constant.int 1 - %int2_4918 = torch.constant.int 2 - %4020 = torch.aten.transpose.int %4011, %int1_4917, %int2_4918 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4020, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_4919 = torch.constant.int 1 - %int2_4920 = torch.constant.int 2 - %4021 = torch.aten.transpose.int %4018, %int1_4919, %int2_4920 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4021, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_4921 = torch.constant.float 0.000000e+00 - %false_4922 = torch.constant.bool false - %none_4923 = torch.constant.none - %4022:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4019, %4020, %4021, %float0.000000e00_4921, %false_4922, %368, %none_4923) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_4924 = torch.constant.int 1 - %int2_4925 = torch.constant.int 2 - %4023 = torch.aten.transpose.int %4022#0, %int1_4924, %int2_4925 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int2_4903 = torch.constant.int 2 + %int8_4904 = torch.constant.int 8 + 
%int32_4905 = torch.constant.int 32 + %int128_4906 = torch.constant.int 128 + %4715 = torch.prim.ListConstruct %456, %int32_4902, %int2_4903, %int8_4904, %int32_4905, %int128_4906 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4716 = torch.aten.view %4702, %4715 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4716, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %4717 = torch_c.to_builtin_tensor %4716 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %4718 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_4907 = tensor.cast %4718 : tensor<4x?xi64> to tensor + %4719 = torch_c.to_builtin_tensor %4706 : !torch.vtensor<[],si64> -> tensor + %4720 = torch_c.to_builtin_tensor %4710 : !torch.vtensor<[],si64> -> tensor + %4721 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%4717, %cast_4907, %4719, %4720) : (tensor, tensor, tensor, tensor) -> tensor + %cast_4908 = tensor.cast %4721 : tensor to tensor<4x?x8x32x128xf16> + %4722 = torch_c.from_builtin_tensor %cast_4908 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %4722, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %4723 = torch_c.to_builtin_tensor %4716 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %4724 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_4909 = tensor.cast %4724 : tensor<4x?xi64> to tensor + %4725 = torch_c.to_builtin_tensor %4706 : !torch.vtensor<[],si64> -> tensor + %4726 = torch_c.to_builtin_tensor %4714 : !torch.vtensor<[],si64> -> tensor + %4727 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%4723, %cast_4909, %4725, %4726) : (tensor, tensor, tensor, tensor) -> tensor + %cast_4910 = tensor.cast %4727 : tensor to tensor<4x?x8x32x128xf16> + %4728 = torch_c.from_builtin_tensor %cast_4910 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %4728, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_4911 = torch.constant.int 2 + %int3_4912 = torch.constant.int 3 + %4729 = torch.aten.transpose.int %4722, %int2_4911, %int3_4912 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4729, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_4913 = torch.constant.int 0 + %4730 = torch.aten.clone %4729, %int0_4913 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4730, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_4914 = torch.constant.int 4 + %int8_4915 = torch.constant.int 8 + %int128_4916 = torch.constant.int 128 + %4731 = torch.prim.ListConstruct %int4_4914, %457, %int8_4915, %int128_4916 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4732 = torch.aten._unsafe_view %4730, %4731 : !torch.vtensor<[4,?,32,8,128],f16>, 
!torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4732, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_4917 = torch.constant.int 2 + %int3_4918 = torch.constant.int 3 + %4733 = torch.aten.transpose.int %4728, %int2_4917, %int3_4918 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4733, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_4919 = torch.constant.int 0 + %4734 = torch.aten.clone %4733, %int0_4919 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4734, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_4920 = torch.constant.int 4 + %int8_4921 = torch.constant.int 8 + %int128_4922 = torch.constant.int 128 + %4735 = torch.prim.ListConstruct %int4_4920, %457, %int8_4921, %int128_4922 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4736 = torch.aten._unsafe_view %4734, %4735 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4736, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_4923 = torch.constant.int -2 + %4737 = torch.aten.unsqueeze %4732, %int-2_4923 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %4737, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_4924 = torch.constant.int 4 + %int8_4925 = torch.constant.int 8 %int4_4926 = torch.constant.int 4 - %int1_4927 = torch.constant.int 1 - %int4096_4928 = torch.constant.int 4096 - %4024 = torch.prim.ListConstruct %int4_4926, %int1_4927, %int4096_4928 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4025 = torch.aten.view %4023, %4024 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_4929 = torch.constant.int -2 - %int-1_4930 = torch.constant.int -1 - %4026 = torch.aten.transpose.int %194, %int-2_4929, %int-1_4930 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_4931 = torch.constant.int 4 - %int4096_4932 = torch.constant.int 4096 - %4027 = torch.prim.ListConstruct %int4_4931, %int4096_4932 : (!torch.int, !torch.int) -> !torch.list - %4028 = torch.aten.view %4025, %4027 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4029 = torch.aten.mm %4028, %4026 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_4933 = torch.constant.int 4 - %int1_4934 = torch.constant.int 1 - %int4096_4935 = torch.constant.int 4096 - %4030 = torch.prim.ListConstruct %int4_4933, %int1_4934, %int4096_4935 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4031 = torch.aten.view %4029, %4030 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_4936 = torch.constant.int 1 - %4032 = torch.aten.add.Tensor %3859, %4031, %int1_4936 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_4937 = torch.constant.int 6 - %4033 = torch.prims.convert_element_type %4032, %int6_4937 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_4938 = torch.constant.int 2 - %4034 = torch.aten.pow.Tensor_Scalar %4033, %int2_4938 : 
!torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_4939 = torch.constant.int -1 - %4035 = torch.prim.ListConstruct %int-1_4939 : (!torch.int) -> !torch.list - %true_4940 = torch.constant.bool true - %none_4941 = torch.constant.none - %4036 = torch.aten.mean.dim %4034, %4035, %true_4940, %none_4941 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_4942 = torch.constant.float 9.9999997473787516E-6 + %int128_4927 = torch.constant.int 128 + %4738 = torch.prim.ListConstruct %int4_4924, %457, %int8_4925, %int4_4926, %int128_4927 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_4928 = torch.constant.bool false + %4739 = torch.aten.expand %4737, %4738, %false_4928 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4739, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_4929 = torch.constant.int 0 + %4740 = torch.aten.clone %4739, %int0_4929 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4740, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_4930 = torch.constant.int 4 + %int32_4931 = torch.constant.int 32 + %int128_4932 = torch.constant.int 128 + %4741 = torch.prim.ListConstruct %int4_4930, %457, %int32_4931, %int128_4932 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4742 = torch.aten._unsafe_view %4740, %4741 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4742, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_4933 = torch.constant.int -2 + %4743 = torch.aten.unsqueeze %4736, %int-2_4933 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %4743, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_4934 = torch.constant.int 4 + %int8_4935 = torch.constant.int 8 + %int4_4936 = torch.constant.int 4 + %int128_4937 = torch.constant.int 128 + %4744 = torch.prim.ListConstruct %int4_4934, %457, %int8_4935, %int4_4936, %int128_4937 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_4938 = torch.constant.bool false + %4745 = torch.aten.expand %4743, %4744, %false_4938 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4745, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_4939 = torch.constant.int 0 + %4746 = torch.aten.clone %4745, %int0_4939 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4746, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_4940 = torch.constant.int 4 + %int32_4941 = torch.constant.int 32 + %int128_4942 = torch.constant.int 128 + %4747 = torch.prim.ListConstruct %int4_4940, %457, %int32_4941, %int128_4942 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4748 = torch.aten._unsafe_view %4746, %4747 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4748, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : 
!torch.vtensor<[4,?,32,128],f16> %int1_4943 = torch.constant.int 1 - %4037 = torch.aten.add.Scalar %4036, %float9.999990e-06_4942, %int1_4943 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %4038 = torch.aten.rsqrt %4037 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %4039 = torch.aten.mul.Tensor %4033, %4038 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_4944 = torch.constant.int 5 - %4040 = torch.prims.convert_element_type %4039, %int5_4944 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %4041 = torch.aten.mul.Tensor %195, %4040 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_4945 = torch.constant.int 5 - %4042 = torch.prims.convert_element_type %4041, %int5_4945 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_4946 = torch.constant.int -2 - %int-1_4947 = torch.constant.int -1 - %4043 = torch.aten.transpose.int %196, %int-2_4946, %int-1_4947 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_4948 = torch.constant.int 4 - %int4096_4949 = torch.constant.int 4096 - %4044 = torch.prim.ListConstruct %int4_4948, %int4096_4949 : (!torch.int, !torch.int) -> !torch.list - %4045 = torch.aten.view %4042, %4044 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4046 = torch.aten.mm %4045, %4043 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_4950 = torch.constant.int 4 - %int1_4951 = torch.constant.int 1 - %int14336_4952 = torch.constant.int 14336 - %4047 = torch.prim.ListConstruct %int4_4950, %int1_4951, %int14336_4952 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4048 = torch.aten.view %4046, %4047 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %4049 = torch.aten.silu %4048 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_4953 = torch.constant.int -2 - %int-1_4954 = torch.constant.int -1 - %4050 = torch.aten.transpose.int %197, %int-2_4953, %int-1_4954 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_4955 = torch.constant.int 4 + %int2_4944 = torch.constant.int 2 + %4749 = torch.aten.transpose.int %4631, %int1_4943, %int2_4944 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_4945 = torch.constant.int 1 + %int2_4946 = torch.constant.int 2 + %4750 = torch.aten.transpose.int %4742, %int1_4945, %int2_4946 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4750, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_4947 = torch.constant.int 1 + %int2_4948 = torch.constant.int 2 + %4751 = torch.aten.transpose.int %4748, %int1_4947, %int2_4948 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4751, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_4949 = torch.constant.float 0.000000e+00 + %false_4950 = torch.constant.bool false + %none_4951 = torch.constant.none + %4752:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4749, %4750, %4751, %float0.000000e00_4949, %false_4950, %470, 
%none_4951) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_4952 = torch.constant.int 1 + %int2_4953 = torch.constant.int 2 + %4753 = torch.aten.transpose.int %4752#0, %int1_4952, %int2_4953 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_4954 = torch.constant.int 4 + %int1_4955 = torch.constant.int 1 %int4096_4956 = torch.constant.int 4096 - %4051 = torch.prim.ListConstruct %int4_4955, %int4096_4956 : (!torch.int, !torch.int) -> !torch.list - %4052 = torch.aten.view %4042, %4051 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4053 = torch.aten.mm %4052, %4050 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_4957 = torch.constant.int 4 - %int1_4958 = torch.constant.int 1 - %int14336_4959 = torch.constant.int 14336 - %4054 = torch.prim.ListConstruct %int4_4957, %int1_4958, %int14336_4959 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4055 = torch.aten.view %4053, %4054 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %4056 = torch.aten.mul.Tensor %4049, %4055 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_4960 = torch.constant.int -2 - %int-1_4961 = torch.constant.int -1 - %4057 = torch.aten.transpose.int %198, %int-2_4960, %int-1_4961 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %4754 = torch.prim.ListConstruct %int4_4954, %int1_4955, %int4096_4956 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4755 = torch.aten.view %4753, %4754 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_4957 = torch.constant.int -2 + %int-1_4958 = torch.constant.int -1 + %4756 = torch.aten.transpose.int %277, %int-2_4957, %int-1_4958 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_4959 = torch.constant.int 5 + %4757 = torch.prims.convert_element_type %4756, %int5_4959 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_4960 = torch.constant.int 4 + %int4096_4961 = torch.constant.int 4096 + %4758 = torch.prim.ListConstruct %int4_4960, %int4096_4961 : (!torch.int, !torch.int) -> !torch.list + %4759 = torch.aten.view %4755, %4758 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4760 = torch.aten.mm %4759, %4757 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_4962 = torch.constant.int 4 - %int14336_4963 = torch.constant.int 14336 - %4058 = torch.prim.ListConstruct %int4_4962, %int14336_4963 : (!torch.int, !torch.int) -> !torch.list - %4059 = torch.aten.view %4056, %4058 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %4060 = torch.aten.mm %4059, %4057 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_4964 = torch.constant.int 4 + %int1_4963 = torch.constant.int 1 + %int4096_4964 = torch.constant.int 4096 + %4761 = torch.prim.ListConstruct %int4_4962, %int1_4963, %int4096_4964 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4762 = torch.aten.view %4760, %4761 : !torch.vtensor<[4,4096],f16>, 
!torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_4965 = torch.constant.int 1 - %int4096_4966 = torch.constant.int 4096 - %4061 = torch.prim.ListConstruct %int4_4964, %int1_4965, %int4096_4966 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4062 = torch.aten.view %4060, %4061 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_4967 = torch.constant.int 1 - %4063 = torch.aten.add.Tensor %4032, %4062, %int1_4967 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_4968 = torch.constant.int 6 - %4064 = torch.prims.convert_element_type %4063, %int6_4968 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_4969 = torch.constant.int 2 - %4065 = torch.aten.pow.Tensor_Scalar %4064, %int2_4969 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_4970 = torch.constant.int -1 - %4066 = torch.prim.ListConstruct %int-1_4970 : (!torch.int) -> !torch.list - %true_4971 = torch.constant.bool true - %none_4972 = torch.constant.none - %4067 = torch.aten.mean.dim %4065, %4066, %true_4971, %none_4972 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_4973 = torch.constant.float 9.9999997473787516E-6 - %int1_4974 = torch.constant.int 1 - %4068 = torch.aten.add.Scalar %4067, %float9.999990e-06_4973, %int1_4974 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %4069 = torch.aten.rsqrt %4068 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %4070 = torch.aten.mul.Tensor %4064, %4069 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_4975 = torch.constant.int 5 - %4071 = torch.prims.convert_element_type %4070, %int5_4975 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %4072 = torch.aten.mul.Tensor %199, %4071 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_4976 = torch.constant.int 5 - %4073 = torch.prims.convert_element_type %4072, %int5_4976 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_4977 = torch.constant.int -2 - %int-1_4978 = torch.constant.int -1 - %4074 = torch.aten.transpose.int %200, %int-2_4977, %int-1_4978 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_4979 = torch.constant.int 4 - %int4096_4980 = torch.constant.int 4096 - %4075 = torch.prim.ListConstruct %int4_4979, %int4096_4980 : (!torch.int, !torch.int) -> !torch.list - %4076 = torch.aten.view %4073, %4075 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4077 = torch.aten.mm %4076, %4074 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_4981 = torch.constant.int 4 - %int1_4982 = torch.constant.int 1 - %int4096_4983 = torch.constant.int 4096 - %4078 = torch.prim.ListConstruct %int4_4981, %int1_4982, %int4096_4983 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4079 = torch.aten.view %4077, %4078 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_4984 = torch.constant.int -2 - %int-1_4985 = torch.constant.int -1 - %4080 = torch.aten.transpose.int %201, %int-2_4984, %int-1_4985 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %4763 = 
torch.aten.add.Tensor %4584, %4762, %int1_4965 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_4966 = torch.constant.int 6 + %4764 = torch.prims.convert_element_type %4763, %int6_4966 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_4967 = torch.constant.int 2 + %4765 = torch.aten.pow.Tensor_Scalar %4764, %int2_4967 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_4968 = torch.constant.int -1 + %4766 = torch.prim.ListConstruct %int-1_4968 : (!torch.int) -> !torch.list + %true_4969 = torch.constant.bool true + %none_4970 = torch.constant.none + %4767 = torch.aten.mean.dim %4765, %4766, %true_4969, %none_4970 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_4971 = torch.constant.float 9.9999997473787516E-6 + %int1_4972 = torch.constant.int 1 + %4768 = torch.aten.add.Scalar %4767, %float9.999990e-06_4971, %int1_4972 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %4769 = torch.aten.rsqrt %4768 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %4770 = torch.aten.mul.Tensor %4764, %4769 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_4973 = torch.constant.int 5 + %4771 = torch.prims.convert_element_type %4770, %int5_4973 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %4772 = torch.aten.mul.Tensor %278, %4771 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_4974 = torch.constant.int 5 + %4773 = torch.prims.convert_element_type %4772, %int5_4974 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_4975 = torch.constant.int -2 + %int-1_4976 = torch.constant.int -1 + %4774 = torch.aten.transpose.int %279, %int-2_4975, %int-1_4976 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_4977 = torch.constant.int 5 + %4775 = torch.prims.convert_element_type %4774, %int5_4977 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_4978 = torch.constant.int 4 + %int4096_4979 = torch.constant.int 4096 + %4776 = torch.prim.ListConstruct %int4_4978, %int4096_4979 : (!torch.int, !torch.int) -> !torch.list + %4777 = torch.aten.view %4773, %4776 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4778 = torch.aten.mm %4777, %4775 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_4980 = torch.constant.int 4 + %int1_4981 = torch.constant.int 1 + %int14336_4982 = torch.constant.int 14336 + %4779 = torch.prim.ListConstruct %int4_4980, %int1_4981, %int14336_4982 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4780 = torch.aten.view %4778, %4779 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %4781 = torch.aten.silu %4780 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_4983 = torch.constant.int -2 + %int-1_4984 = torch.constant.int -1 + %4782 = torch.aten.transpose.int %280, %int-2_4983, %int-1_4984 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_4985 = torch.constant.int 5 + %4783 = torch.prims.convert_element_type %4782, %int5_4985 : !torch.vtensor<[4096,14336],f16>, 
!torch.int -> !torch.vtensor<[4096,14336],f16> %int4_4986 = torch.constant.int 4 %int4096_4987 = torch.constant.int 4096 - %4081 = torch.prim.ListConstruct %int4_4986, %int4096_4987 : (!torch.int, !torch.int) -> !torch.list - %4082 = torch.aten.view %4073, %4081 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4083 = torch.aten.mm %4082, %4080 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %4784 = torch.prim.ListConstruct %int4_4986, %int4096_4987 : (!torch.int, !torch.int) -> !torch.list + %4785 = torch.aten.view %4773, %4784 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4786 = torch.aten.mm %4785, %4783 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> %int4_4988 = torch.constant.int 4 %int1_4989 = torch.constant.int 1 - %int1024_4990 = torch.constant.int 1024 - %4084 = torch.prim.ListConstruct %int4_4988, %int1_4989, %int1024_4990 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4085 = torch.aten.view %4083, %4084 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int14336_4990 = torch.constant.int 14336 + %4787 = torch.prim.ListConstruct %int4_4988, %int1_4989, %int14336_4990 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4788 = torch.aten.view %4786, %4787 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %4789 = torch.aten.mul.Tensor %4781, %4788 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> %int-2_4991 = torch.constant.int -2 %int-1_4992 = torch.constant.int -1 - %4086 = torch.aten.transpose.int %202, %int-2_4991, %int-1_4992 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_4993 = torch.constant.int 4 - %int4096_4994 = torch.constant.int 4096 - %4087 = torch.prim.ListConstruct %int4_4993, %int4096_4994 : (!torch.int, !torch.int) -> !torch.list - %4088 = torch.aten.view %4073, %4087 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4089 = torch.aten.mm %4088, %4086 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_4995 = torch.constant.int 4 - %int1_4996 = torch.constant.int 1 - %int1024_4997 = torch.constant.int 1024 - %4090 = torch.prim.ListConstruct %int4_4995, %int1_4996, %int1024_4997 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4091 = torch.aten.view %4089, %4090 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_4998 = torch.constant.int 4 + %4790 = torch.aten.transpose.int %281, %int-2_4991, %int-1_4992 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_4993 = torch.constant.int 5 + %4791 = torch.prims.convert_element_type %4790, %int5_4993 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_4994 = torch.constant.int 4 + %int14336_4995 = torch.constant.int 14336 + %4792 = torch.prim.ListConstruct %int4_4994, %int14336_4995 : (!torch.int, !torch.int) -> !torch.list + %4793 = torch.aten.view %4789, %4792 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %4794 = torch.aten.mm %4793, %4791 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_4996 = torch.constant.int 4 + %int1_4997 = torch.constant.int 1 + %int4096_4998 = 
torch.constant.int 4096 + %4795 = torch.prim.ListConstruct %int4_4996, %int1_4997, %int4096_4998 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4796 = torch.aten.view %4794, %4795 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_4999 = torch.constant.int 1 - %int32_5000 = torch.constant.int 32 - %int128_5001 = torch.constant.int 128 - %4092 = torch.prim.ListConstruct %int4_4998, %int1_4999, %int32_5000, %int128_5001 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4093 = torch.aten.view %4079, %4092 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_5002 = torch.constant.int 4 - %int1_5003 = torch.constant.int 1 - %int8_5004 = torch.constant.int 8 - %int128_5005 = torch.constant.int 128 - %4094 = torch.prim.ListConstruct %int4_5002, %int1_5003, %int8_5004, %int128_5005 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4095 = torch.aten.view %4085, %4094 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_5006 = torch.constant.int 4 - %int1_5007 = torch.constant.int 1 - %int8_5008 = torch.constant.int 8 - %int128_5009 = torch.constant.int 128 - %4096 = torch.prim.ListConstruct %int4_5006, %int1_5007, %int8_5008, %int128_5009 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4097 = torch.aten.view %4091, %4096 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_5010 = torch.constant.int 6 - %4098 = torch.prims.convert_element_type %4093, %int6_5010 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %4099 = torch_c.to_builtin_tensor %4098 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %4100 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %4101 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%4099, %4100) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %4102 = torch_c.from_builtin_tensor %4101 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> + %4797 = torch.aten.add.Tensor %4763, %4796, %int1_4999 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_5000 = torch.constant.int 6 + %4798 = torch.prims.convert_element_type %4797, %int6_5000 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_5001 = torch.constant.int 2 + %4799 = torch.aten.pow.Tensor_Scalar %4798, %int2_5001 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_5002 = torch.constant.int -1 + %4800 = torch.prim.ListConstruct %int-1_5002 : (!torch.int) -> !torch.list + %true_5003 = torch.constant.bool true + %none_5004 = torch.constant.none + %4801 = torch.aten.mean.dim %4799, %4800, %true_5003, %none_5004 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_5005 = torch.constant.float 9.9999997473787516E-6 + %int1_5006 = torch.constant.int 1 + %4802 = torch.aten.add.Scalar %4801, %float9.999990e-06_5005, %int1_5006 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %4803 = torch.aten.rsqrt %4802 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %4804 = torch.aten.mul.Tensor %4798, %4803 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_5007 = torch.constant.int 5 + %4805 = 
torch.prims.convert_element_type %4804, %int5_5007 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %4806 = torch.aten.mul.Tensor %282, %4805 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_5008 = torch.constant.int 5 + %4807 = torch.prims.convert_element_type %4806, %int5_5008 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_5009 = torch.constant.int -2 + %int-1_5010 = torch.constant.int -1 + %4808 = torch.aten.transpose.int %283, %int-2_5009, %int-1_5010 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> %int5_5011 = torch.constant.int 5 - %4103 = torch.prims.convert_element_type %4102, %int5_5011 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_5012 = torch.constant.int 6 - %4104 = torch.prims.convert_element_type %4095, %int6_5012 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %4105 = torch_c.to_builtin_tensor %4104 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %4106 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %4107 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%4105, %4106) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %4108 = torch_c.from_builtin_tensor %4107 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_5013 = torch.constant.int 5 - %4109 = torch.prims.convert_element_type %4108, %int5_5013 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_5014 = torch.constant.int 32 - %4110 = torch.aten.floor_divide.Scalar %arg2, %int32_5014 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %4809 = torch.prims.convert_element_type %4808, %int5_5011 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_5012 = torch.constant.int 4 + %int4096_5013 = torch.constant.int 4096 + %4810 = torch.prim.ListConstruct %int4_5012, %int4096_5013 : (!torch.int, !torch.int) -> !torch.list + %4811 = torch.aten.view %4807, %4810 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4812 = torch.aten.mm %4811, %4809 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_5014 = torch.constant.int 4 %int1_5015 = torch.constant.int 1 - %4111 = torch.aten.unsqueeze %4110, %int1_5015 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5016 = torch.constant.int 1 - %false_5017 = torch.constant.bool false - %4112 = torch.aten.gather %arg3, %int1_5016, %4111, %false_5017 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_5018 = torch.constant.int 32 - %4113 = torch.aten.remainder.Scalar %arg2, %int32_5018 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_5019 = torch.constant.int 1 - %4114 = torch.aten.unsqueeze %4113, %int1_5019 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_5020 = torch.constant.none - %4115 = torch.aten.clone %203, %none_5020 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_5021 = torch.constant.int 0 - %4116 = torch.aten.unsqueeze %4115, %int0_5021 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + %int4096_5016 = torch.constant.int 4096 + %4813 = torch.prim.ListConstruct %int4_5014, 
%int1_5015, %int4096_5016 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4814 = torch.aten.view %4812, %4813 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_5017 = torch.constant.int -2 + %int-1_5018 = torch.constant.int -1 + %4815 = torch.aten.transpose.int %284, %int-2_5017, %int-1_5018 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5019 = torch.constant.int 5 + %4816 = torch.prims.convert_element_type %4815, %int5_5019 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_5020 = torch.constant.int 4 + %int4096_5021 = torch.constant.int 4096 + %4817 = torch.prim.ListConstruct %int4_5020, %int4096_5021 : (!torch.int, !torch.int) -> !torch.list + %4818 = torch.aten.view %4807, %4817 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4819 = torch.aten.mm %4818, %4816 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> %int4_5022 = torch.constant.int 4 %int1_5023 = torch.constant.int 1 - %4117 = torch.prim.ListConstruct %int4_5022, %int1_5023 : (!torch.int, !torch.int) -> !torch.list - %int1_5024 = torch.constant.int 1 - %int1_5025 = torch.constant.int 1 - %4118 = torch.prim.ListConstruct %int1_5024, %int1_5025 : (!torch.int, !torch.int) -> !torch.list - %int4_5026 = torch.constant.int 4 - %int0_5027 = torch.constant.int 0 - %cpu_5028 = torch.constant.device "cpu" - %false_5029 = torch.constant.bool false - %4119 = torch.aten.empty_strided %4117, %4118, %int4_5026, %int0_5027, %cpu_5028, %false_5029 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int18 = torch.constant.int 18 - %4120 = torch.aten.fill.Scalar %4119, %int18 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1024_5024 = torch.constant.int 1024 + %4820 = torch.prim.ListConstruct %int4_5022, %int1_5023, %int1024_5024 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4821 = torch.aten.view %4819, %4820 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_5025 = torch.constant.int -2 + %int-1_5026 = torch.constant.int -1 + %4822 = torch.aten.transpose.int %285, %int-2_5025, %int-1_5026 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5027 = torch.constant.int 5 + %4823 = torch.prims.convert_element_type %4822, %int5_5027 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_5028 = torch.constant.int 4 + %int4096_5029 = torch.constant.int 4096 + %4824 = torch.prim.ListConstruct %int4_5028, %int4096_5029 : (!torch.int, !torch.int) -> !torch.list + %4825 = torch.aten.view %4807, %4824 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4826 = torch.aten.mm %4825, %4823 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> %int4_5030 = torch.constant.int 4 %int1_5031 = torch.constant.int 1 - %4121 = torch.prim.ListConstruct %int4_5030, %int1_5031 : (!torch.int, !torch.int) -> !torch.list - %4122 = torch.aten.repeat %4116, %4121 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_5032 = torch.constant.int 32 - %4123 = torch.aten.mul.Scalar %4112, %int32_5032 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5033 = torch.constant.int 1 - %4124 = torch.aten.add.Tensor %4123, %4120, %int1_5033 
: !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_5034 = torch.constant.int 2 - %4125 = torch.aten.mul.Scalar %4124, %int2_5034 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5035 = torch.constant.int 1 - %4126 = torch.aten.add.Tensor %4125, %4122, %int1_5035 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_5036 = torch.constant.int 32 - %4127 = torch.aten.mul.Scalar %4126, %int32_5036 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5037 = torch.constant.int 1 - %4128 = torch.aten.add.Tensor %4127, %4114, %int1_5037 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_5038 = torch.constant.int 32 - %int2_5039 = torch.constant.int 2 - %int32_5040 = torch.constant.int 32 - %int8_5041 = torch.constant.int 8 - %int128_5042 = torch.constant.int 128 - %4129 = torch.prim.ListConstruct %437, %int32_5038, %int2_5039, %int32_5040, %int8_5041, %int128_5042 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4130 = torch.aten.view %3966, %4129 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4130, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5043 = torch.constant.int 32 - %4131 = torch.aten.mul.int %437, %int32_5043 : !torch.int, !torch.int -> !torch.int - %int2_5044 = torch.constant.int 2 - %4132 = torch.aten.mul.int %4131, %int2_5044 : !torch.int, !torch.int -> !torch.int - %int32_5045 = torch.constant.int 32 - %4133 = torch.aten.mul.int %4132, %int32_5045 : !torch.int, !torch.int -> !torch.int - %int8_5046 = torch.constant.int 8 - %int128_5047 = torch.constant.int 128 - %4134 = torch.prim.ListConstruct %4133, %int8_5046, %int128_5047 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4135 = torch.aten.view %4130, %4134 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4135, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %4136 = torch.prim.ListConstruct %4128 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_5048 = torch.constant.bool false - %4137 = torch.aten.index_put %4135, %4136, %4109, %false_5048 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4137, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_5049 = torch.constant.int 32 - %int2_5050 = torch.constant.int 2 - %int32_5051 = torch.constant.int 32 - %int8_5052 = torch.constant.int 8 - %int128_5053 = torch.constant.int 128 - %4138 = torch.prim.ListConstruct %437, %int32_5049, %int2_5050, %int32_5051, %int8_5052, %int128_5053 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4139 = torch.aten.view %4137, %4138 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4139, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5054 = torch.constant.int 2097152 - %4140 = torch.prim.ListConstruct %437, %int2097152_5054 : (!torch.int, !torch.int) -> !torch.list - %4141 = torch.aten.view %4139, %4140 : !torch.vtensor<[?,32,2,32,8,128],f16>, 
!torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4141, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_5055 = torch.constant.int 32 - %int2_5056 = torch.constant.int 2 - %int32_5057 = torch.constant.int 32 - %int8_5058 = torch.constant.int 8 - %int128_5059 = torch.constant.int 128 - %4142 = torch.prim.ListConstruct %437, %int32_5055, %int2_5056, %int32_5057, %int8_5058, %int128_5059 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4143 = torch.aten.view %4141, %4142 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4143, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_5060 = torch.constant.int 8 - %int128_5061 = torch.constant.int 128 - %4144 = torch.prim.ListConstruct %4133, %int8_5060, %int128_5061 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4145 = torch.aten.view %4143, %4144 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4145, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_5062 = torch.constant.int 32 - %4146 = torch.aten.floor_divide.Scalar %arg2, %int32_5062 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_5063 = torch.constant.int 1 - %4147 = torch.aten.unsqueeze %4146, %int1_5063 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1024_5032 = torch.constant.int 1024 + %4827 = torch.prim.ListConstruct %int4_5030, %int1_5031, %int1024_5032 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4828 = torch.aten.view %4826, %4827 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_5033 = torch.constant.int 4 + %int1_5034 = torch.constant.int 1 + %int32_5035 = torch.constant.int 32 + %int128_5036 = torch.constant.int 128 + %4829 = torch.prim.ListConstruct %int4_5033, %int1_5034, %int32_5035, %int128_5036 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4830 = torch.aten.view %4814, %4829 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_5037 = torch.constant.int 4 + %int1_5038 = torch.constant.int 1 + %int8_5039 = torch.constant.int 8 + %int128_5040 = torch.constant.int 128 + %4831 = torch.prim.ListConstruct %int4_5037, %int1_5038, %int8_5039, %int128_5040 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4832 = torch.aten.view %4821, %4831 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_5041 = torch.constant.int 4 + %int1_5042 = torch.constant.int 1 + %int8_5043 = torch.constant.int 8 + %int128_5044 = torch.constant.int 128 + %4833 = torch.prim.ListConstruct %int4_5041, %int1_5042, %int8_5043, %int128_5044 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4834 = torch.aten.view %4828, %4833 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_5045 = torch.constant.int 1 + %int2_5046 = torch.constant.int 2 + %4835 = torch.aten.transpose.int %4830, %int1_5045, %int2_5046 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %4836 = torch.aten.mul.Tensor %4835, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_5047 = torch.constant.int 3 + %int0_5048 = torch.constant.int 0 + 
%int64_5049 = torch.constant.int 64 + %int1_5050 = torch.constant.int 1 + %4837 = torch.aten.slice.Tensor %4835, %int3_5047, %int0_5048, %int64_5049, %int1_5050 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_5051 = torch.constant.int 3 + %int64_5052 = torch.constant.int 64 + %int9223372036854775807_5053 = torch.constant.int 9223372036854775807 + %int1_5054 = torch.constant.int 1 + %4838 = torch.aten.slice.Tensor %4835, %int3_5051, %int64_5052, %int9223372036854775807_5053, %int1_5054 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %4839 = torch.aten.neg %4838 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %4840 = torch.prim.ListConstruct %4839, %4837 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_5055 = torch.constant.int -1 + %4841 = torch.aten.cat %4840, %int-1_5055 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %4842 = torch.aten.mul.Tensor %4841, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_5056 = torch.constant.int 1 + %4843 = torch.aten.add.Tensor %4836, %4842, %int1_5056 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_5057 = torch.constant.int 1 + %int2_5058 = torch.constant.int 2 + %4844 = torch.aten.transpose.int %4843, %int1_5057, %int2_5058 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_5059 = torch.constant.int 1 + %int2_5060 = torch.constant.int 2 + %4845 = torch.aten.transpose.int %4832, %int1_5059, %int2_5060 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %4846 = torch.aten.mul.Tensor %4845, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_5061 = torch.constant.int 3 + %int0_5062 = torch.constant.int 0 + %int64_5063 = torch.constant.int 64 %int1_5064 = torch.constant.int 1 - %false_5065 = torch.constant.bool false - %4148 = torch.aten.gather %arg3, %int1_5064, %4147, %false_5065 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_5066 = torch.constant.int 32 - %4149 = torch.aten.remainder.Scalar %arg2, %int32_5066 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_5067 = torch.constant.int 1 - %4150 = torch.aten.unsqueeze %4149, %int1_5067 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_5068 = torch.constant.none - %4151 = torch.aten.clone %204, %none_5068 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_5069 = torch.constant.int 0 - %4152 = torch.aten.unsqueeze %4151, %int0_5069 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_5070 = torch.constant.int 4 + %4847 = torch.aten.slice.Tensor %4845, %int3_5061, %int0_5062, %int64_5063, %int1_5064 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_5065 = torch.constant.int 3 + %int64_5066 = torch.constant.int 64 + %int9223372036854775807_5067 = torch.constant.int 9223372036854775807 + %int1_5068 = torch.constant.int 1 + %4848 = torch.aten.slice.Tensor %4845, %int3_5065, %int64_5066, %int9223372036854775807_5067, %int1_5068 : 
!torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %4849 = torch.aten.neg %4848 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %4850 = torch.prim.ListConstruct %4849, %4847 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_5069 = torch.constant.int -1 + %4851 = torch.aten.cat %4850, %int-1_5069 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %4852 = torch.aten.mul.Tensor %4851, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_5070 = torch.constant.int 1 + %4853 = torch.aten.add.Tensor %4846, %4852, %int1_5070 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> %int1_5071 = torch.constant.int 1 - %4153 = torch.prim.ListConstruct %int4_5070, %int1_5071 : (!torch.int, !torch.int) -> !torch.list - %int1_5072 = torch.constant.int 1 - %int1_5073 = torch.constant.int 1 - %4154 = torch.prim.ListConstruct %int1_5072, %int1_5073 : (!torch.int, !torch.int) -> !torch.list - %int4_5074 = torch.constant.int 4 - %int0_5075 = torch.constant.int 0 - %cpu_5076 = torch.constant.device "cpu" - %false_5077 = torch.constant.bool false - %4155 = torch.aten.empty_strided %4153, %4154, %int4_5074, %int0_5075, %cpu_5076, %false_5077 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int18_5078 = torch.constant.int 18 - %4156 = torch.aten.fill.Scalar %4155, %int18_5078 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_5079 = torch.constant.int 4 - %int1_5080 = torch.constant.int 1 - %4157 = torch.prim.ListConstruct %int4_5079, %int1_5080 : (!torch.int, !torch.int) -> !torch.list - %4158 = torch.aten.repeat %4152, %4157 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_5081 = torch.constant.int 32 - %4159 = torch.aten.mul.Scalar %4148, %int32_5081 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int2_5072 = torch.constant.int 2 + %4854 = torch.aten.transpose.int %4853, %int1_5071, %int2_5072 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_5073 = torch.constant.int 32 + %4855 = torch.aten.floor_divide.Scalar %arg2, %int32_5073 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_5074 = torch.constant.int 1 + %4856 = torch.aten.unsqueeze %4855, %int1_5074 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_5075 = torch.constant.int 1 + %false_5076 = torch.constant.bool false + %4857 = torch.aten.gather %arg3, %int1_5075, %4856, %false_5076 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_5077 = torch.constant.int 4 + %int1_5078 = torch.constant.int 1 + %int1_5079 = torch.constant.int 1 + %4858 = torch.prim.ListConstruct %int4_5077, %int1_5078, %int1_5079 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4859 = torch.aten.view %4857, %4858 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_5080 = torch.constant.int 32 + %4860 = torch.aten.remainder.Scalar %arg2, %int32_5080 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_5081 = torch.constant.int 4 %int1_5082 = torch.constant.int 1 - %4160 = torch.aten.add.Tensor %4159, %4156, %int1_5082 : !torch.vtensor<[4,1],si64>, 
!torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_5083 = torch.constant.int 2 - %4161 = torch.aten.mul.Scalar %4160, %int2_5083 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5084 = torch.constant.int 1 - %4162 = torch.aten.add.Tensor %4161, %4158, %int1_5084 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_5085 = torch.constant.int 32 - %4163 = torch.aten.mul.Scalar %4162, %int32_5085 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5086 = torch.constant.int 1 - %4164 = torch.aten.add.Tensor %4163, %4150, %int1_5086 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %4165 = torch.prim.ListConstruct %4164 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_5087 = torch.constant.bool false - %4166 = torch.aten.index_put %4145, %4165, %4097, %false_5087 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4166, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_5088 = torch.constant.int 32 - %int2_5089 = torch.constant.int 2 - %int32_5090 = torch.constant.int 32 + %int1_5083 = torch.constant.int 1 + %4861 = torch.prim.ListConstruct %int4_5081, %int1_5082, %int1_5083 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4862 = torch.aten.view %4860, %4861 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_5084 = torch.constant.int 8 + %none_5085 = torch.constant.none + %none_5086 = torch.constant.none + %cpu_5087 = torch.constant.device "cpu" + %false_5088 = torch.constant.bool false + %4863 = torch.aten.arange %int8_5084, %none_5085, %none_5086, %cpu_5087, %false_5088 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_5089 = torch.constant.int 1 + %int1_5090 = torch.constant.int 1 %int8_5091 = torch.constant.int 8 - %int128_5092 = torch.constant.int 128 - %4167 = torch.prim.ListConstruct %437, %int32_5088, %int2_5089, %int32_5090, %int8_5091, %int128_5092 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4168 = torch.aten.view %4166, %4167 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4168, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5093 = torch.constant.int 2097152 - %4169 = torch.prim.ListConstruct %437, %int2097152_5093 : (!torch.int, !torch.int) -> !torch.list - %4170 = torch.aten.view %4168, %4169 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4170, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_5094 = torch.constant.int 4 - %4171 = torch.prim.ListConstruct %int4_5094, %358 : (!torch.int, !torch.int) -> !torch.list + %4864 = torch.prim.ListConstruct %int1_5089, %int1_5090, %int8_5091 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4865 = torch.aten.view %4863, %4864 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_5092 = torch.constant.none + %4866 = torch.aten.clone %286, %none_5092 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4867 = torch.aten.detach %4866 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4868 = 
torch.aten.detach %4867 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4869 = torch.aten.detach %4868 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_5093 = torch.constant.int 1 + %int1_5094 = torch.constant.int 1 %int1_5095 = torch.constant.int 1 - %4172 = torch.prim.ListConstruct %358, %int1_5095 : (!torch.int, !torch.int) -> !torch.list - %int4_5096 = torch.constant.int 4 - %int0_5097 = torch.constant.int 0 - %cpu_5098 = torch.constant.device "cpu" - %false_5099 = torch.constant.bool false - %4173 = torch.aten.empty_strided %4171, %4172, %int4_5096, %int0_5097, %cpu_5098, %false_5099 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4173, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int18_5100 = torch.constant.int 18 - %4174 = torch.aten.fill.Scalar %4173, %int18_5100 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4174, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_5101 = torch.constant.int 32 - %4175 = torch.aten.mul.Scalar %arg3, %int32_5101 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4175, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_5102 = torch.constant.int 1 - %4176 = torch.aten.add.Tensor %4175, %4174, %int1_5102 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4176, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_5103 = torch.constant.int 4 - %4177 = torch.aten.mul.int %int4_5103, %358 : !torch.int, !torch.int -> !torch.int - %4178 = torch.prim.ListConstruct %4177 : (!torch.int) -> !torch.list - %4179 = torch.aten.view %4176, %4178 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %4179, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_5104 = torch.constant.int 32 - %int2_5105 = torch.constant.int 2 - %int32_5106 = torch.constant.int 32 + %4870 = torch.prim.ListConstruct %int1_5093, %int1_5094, %int1_5095 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4871 = torch.aten.view %4869, %4870 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_5096 = torch.constant.int 32 + %4872 = torch.aten.mul.Scalar %4859, %int32_5096 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int20 = torch.constant.int 20 + %int1_5097 = torch.constant.int 1 + %4873 = torch.aten.add.Scalar %4872, %int20, %int1_5097 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_5098 = torch.constant.int 2 + %4874 = torch.aten.mul.Scalar %4873, %int2_5098 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_5099 = torch.constant.int 1 + %4875 = torch.aten.add.Tensor %4874, %4871, %int1_5099 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_5100 = torch.constant.int 8 + %4876 = torch.aten.mul.Scalar %4875, %int8_5100 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_5101 = torch.constant.int 1 + %4877 = torch.aten.add.Tensor %4876, %4865, %int1_5101 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_5102 = torch.constant.int 32 + %4878 = 
torch.aten.mul.Scalar %4877, %int32_5102 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_5103 = torch.constant.int 1 + %4879 = torch.aten.add.Tensor %4878, %4862, %int1_5103 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_5104 = torch.constant.int 5 + %4880 = torch.prims.convert_element_type %4854, %int5_5104 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_5105 = torch.constant.int 32 + %int2_5106 = torch.constant.int 2 %int8_5107 = torch.constant.int 8 - %int128_5108 = torch.constant.int 128 - %4180 = torch.prim.ListConstruct %437, %int32_5104, %int2_5105, %int32_5106, %int8_5107, %int128_5108 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4181 = torch.aten.view %4170, %4180 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4181, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5109 = torch.constant.int 32 - %4182 = torch.aten.mul.int %437, %int32_5109 : !torch.int, !torch.int -> !torch.int - %int2_5110 = torch.constant.int 2 - %int32_5111 = torch.constant.int 32 - %int8_5112 = torch.constant.int 8 - %int128_5113 = torch.constant.int 128 - %4183 = torch.prim.ListConstruct %4182, %int2_5110, %int32_5111, %int8_5112, %int128_5113 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4184 = torch.aten.view %4181, %4183 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %4184, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_5114 = torch.constant.int 0 - %4185 = torch.aten.index_select %4184, %int0_5114, %4179 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %4185, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_5115 = torch.constant.int 4 - %int2_5116 = torch.constant.int 2 - %int32_5117 = torch.constant.int 32 - %int8_5118 = torch.constant.int 8 - %int128_5119 = torch.constant.int 128 - %4186 = torch.prim.ListConstruct %int4_5115, %358, %int2_5116, %int32_5117, %int8_5118, %int128_5119 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4187 = torch.aten.view %4185, %4186 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4187, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_5120 = torch.constant.int 0 - %int0_5121 = torch.constant.int 0 - %int9223372036854775807_5122 = torch.constant.int 9223372036854775807 - %int1_5123 = torch.constant.int 1 - %4188 = torch.aten.slice.Tensor %4187, %int0_5120, %int0_5121, %int9223372036854775807_5122, %int1_5123 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4188, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_5124 = torch.constant.int 1 - %int0_5125 = torch.constant.int 0 - %int9223372036854775807_5126 = torch.constant.int 9223372036854775807 + %int32_5108 = torch.constant.int 32 + %int128_5109 = torch.constant.int 128 + %4881 = 
torch.prim.ListConstruct %456, %int32_5105, %int2_5106, %int8_5107, %int32_5108, %int128_5109 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4882 = torch.aten.view %4702, %4881 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4882, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_5110 = torch.constant.int 128 + %4883 = torch.prim.ListConstruct %596, %int128_5110 : (!torch.int, !torch.int) -> !torch.list + %4884 = torch.aten.view %4882, %4883 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4884, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %4885 = torch.prim.ListConstruct %4879 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_5111 = torch.constant.bool false + %4886 = torch.aten.index_put %4884, %4885, %4880, %false_5111 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4886, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_5112 = torch.constant.int 32 + %int2_5113 = torch.constant.int 2 + %int8_5114 = torch.constant.int 8 + %int32_5115 = torch.constant.int 32 + %int128_5116 = torch.constant.int 128 + %4887 = torch.prim.ListConstruct %456, %int32_5112, %int2_5113, %int8_5114, %int32_5115, %int128_5116 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4888 = torch.aten.view %4886, %4887 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4888, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_5117 = torch.constant.int 2097152 + %4889 = torch.prim.ListConstruct %456, %int2097152_5117 : (!torch.int, !torch.int) -> !torch.list + %4890 = torch.aten.view %4888, %4889 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %4890, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_5118 = torch.constant.int 32 + %int2_5119 = torch.constant.int 2 + %int8_5120 = torch.constant.int 8 + %int32_5121 = torch.constant.int 32 + %int128_5122 = torch.constant.int 128 + %4891 = torch.prim.ListConstruct %456, %int32_5118, %int2_5119, %int8_5120, %int32_5121, %int128_5122 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4892 = torch.aten.view %4890, %4891 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4892, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_5123 = torch.constant.int 128 + %4893 = torch.prim.ListConstruct %596, %int128_5123 : (!torch.int, !torch.int) -> !torch.list + %4894 = torch.aten.view %4892, %4893 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4894, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_5124 = torch.constant.none + %4895 = torch.aten.clone %287, %none_5124 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4896 = torch.aten.detach %4895 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4897 = torch.aten.detach %4896 : 
!torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4898 = torch.aten.detach %4897 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_5125 = torch.constant.int 1 + %int1_5126 = torch.constant.int 1 %int1_5127 = torch.constant.int 1 - %4189 = torch.aten.slice.Tensor %4188, %int1_5124, %int0_5125, %int9223372036854775807_5126, %int1_5127 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4189, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_5128 = torch.constant.int 2 - %int0_5129 = torch.constant.int 0 - %4190 = torch.aten.select.int %4189, %int2_5128, %int0_5129 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4190, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_5130 = torch.constant.int 32 - %4191 = torch.aten.mul.int %358, %int32_5130 : !torch.int, !torch.int -> !torch.int + %4899 = torch.prim.ListConstruct %int1_5125, %int1_5126, %int1_5127 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4900 = torch.aten.view %4898, %4899 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_5128 = torch.constant.int 32 + %4901 = torch.aten.mul.Scalar %4859, %int32_5128 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int20_5129 = torch.constant.int 20 + %int1_5130 = torch.constant.int 1 + %4902 = torch.aten.add.Scalar %4901, %int20_5129, %int1_5130 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> %int2_5131 = torch.constant.int 2 - %int0_5132 = torch.constant.int 0 - %int1_5133 = torch.constant.int 1 - %4192 = torch.aten.slice.Tensor %4190, %int2_5131, %int0_5132, %4191, %int1_5133 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4192, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_5134 = torch.constant.int 0 - %4193 = torch.aten.clone %4192, %int0_5134 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4193, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_5135 = torch.constant.int 1 - %4194 = torch.aten.size.int %4189, %int1_5135 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_5136 = torch.constant.int 32 - %4195 = torch.aten.mul.int %4194, %int32_5136 : !torch.int, !torch.int -> !torch.int - %int4_5137 = torch.constant.int 4 - %int8_5138 = torch.constant.int 8 - %int128_5139 = torch.constant.int 128 - %4196 = torch.prim.ListConstruct %int4_5137, %4195, %int8_5138, %int128_5139 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4197 = torch.aten._unsafe_view %4193, %4196 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4197, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_5140 = torch.constant.int 0 - %int0_5141 = torch.constant.int 0 - %int9223372036854775807_5142 = torch.constant.int 9223372036854775807 - %int1_5143 = torch.constant.int 1 - %4198 = torch.aten.slice.Tensor %4197, %int0_5140, %int0_5141, %int9223372036854775807_5142, %int1_5143 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, 
!torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4198, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_5144 = torch.constant.int 0 - %int0_5145 = torch.constant.int 0 - %int9223372036854775807_5146 = torch.constant.int 9223372036854775807 - %int1_5147 = torch.constant.int 1 - %4199 = torch.aten.slice.Tensor %4187, %int0_5144, %int0_5145, %int9223372036854775807_5146, %int1_5147 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4199, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_5148 = torch.constant.int 1 - %int0_5149 = torch.constant.int 0 - %int9223372036854775807_5150 = torch.constant.int 9223372036854775807 - %int1_5151 = torch.constant.int 1 - %4200 = torch.aten.slice.Tensor %4199, %int1_5148, %int0_5149, %int9223372036854775807_5150, %int1_5151 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4200, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_5152 = torch.constant.int 2 - %int1_5153 = torch.constant.int 1 - %4201 = torch.aten.select.int %4200, %int2_5152, %int1_5153 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4201, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_5154 = torch.constant.int 2 - %int0_5155 = torch.constant.int 0 - %int1_5156 = torch.constant.int 1 - %4202 = torch.aten.slice.Tensor %4201, %int2_5154, %int0_5155, %4191, %int1_5156 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4202, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_5157 = torch.constant.int 0 - %4203 = torch.aten.clone %4202, %int0_5157 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4203, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_5158 = torch.constant.int 1 - %4204 = torch.aten.size.int %4200, %int1_5158 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_5159 = torch.constant.int 32 - %4205 = torch.aten.mul.int %4204, %int32_5159 : !torch.int, !torch.int -> !torch.int + %4903 = torch.aten.mul.Scalar %4902, %int2_5131 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_5132 = torch.constant.int 1 + %4904 = torch.aten.add.Tensor %4903, %4900, %int1_5132 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_5133 = torch.constant.int 8 + %4905 = torch.aten.mul.Scalar %4904, %int8_5133 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_5134 = torch.constant.int 1 + %4906 = torch.aten.add.Tensor %4905, %4865, %int1_5134 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_5135 = torch.constant.int 32 + %4907 = torch.aten.mul.Scalar %4906, %int32_5135 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_5136 = torch.constant.int 1 + %4908 = torch.aten.add.Tensor %4907, %4862, 
%int1_5136 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_5137 = torch.constant.int 5 + %4909 = torch.prims.convert_element_type %4834, %int5_5137 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %4910 = torch.prim.ListConstruct %4908 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_5138 = torch.constant.bool false + %4911 = torch.aten.index_put %4894, %4910, %4909, %false_5138 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %4911, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_5139 = torch.constant.int 32 + %int2_5140 = torch.constant.int 2 + %int8_5141 = torch.constant.int 8 + %int32_5142 = torch.constant.int 32 + %int128_5143 = torch.constant.int 128 + %4912 = torch.prim.ListConstruct %456, %int32_5139, %int2_5140, %int8_5141, %int32_5142, %int128_5143 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4913 = torch.aten.view %4911, %4912 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4913, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_5144 = torch.constant.int 2097152 + %4914 = torch.prim.ListConstruct %456, %int2097152_5144 : (!torch.int, !torch.int) -> !torch.list + %4915 = torch.aten.view %4913, %4914 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %4915, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_5145 = torch.constant.none + %4916 = torch.aten.clone %288, %none_5145 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4917 = torch.aten.detach %4916 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4918 = torch.aten.detach %4917 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4919 = torch.aten.detach %4918 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_5146 = torch.constant.none + %4920 = torch.aten.clone %289, %none_5146 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4921 = torch.aten.detach %4920 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4922 = torch.aten.detach %4921 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4923 = torch.aten.detach %4922 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_5147 = torch.constant.none + %4924 = torch.aten.clone %290, %none_5147 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %4925 = torch.aten.detach %4924 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4926 = torch.aten.detach %4925 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %4927 = torch.aten.detach %4926 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_5148 = torch.constant.int 32 + %int2_5149 = torch.constant.int 2 + %int8_5150 = torch.constant.int 8 + %int32_5151 = torch.constant.int 32 + %int128_5152 = torch.constant.int 128 + %4928 = torch.prim.ListConstruct %456, %int32_5148, %int2_5149, %int8_5150, %int32_5151, %int128_5152 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4929 = torch.aten.view %4915, %4928 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %4929, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 
32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %4930 = torch_c.to_builtin_tensor %4929 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16> + %4931 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_5153 = tensor.cast %4931 : tensor<4x?xi64> to tensor<?x?xi64> + %4932 = torch_c.to_builtin_tensor %4919 : !torch.vtensor<[],si64> -> tensor<i64> + %4933 = torch_c.to_builtin_tensor %4923 : !torch.vtensor<[],si64> -> tensor<i64> + %4934 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%4930, %cast_5153, %4932, %4933) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16> + %cast_5154 = tensor.cast %4934 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16> + %4935 = torch_c.from_builtin_tensor %cast_5154 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %4935, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %4936 = torch_c.to_builtin_tensor %4929 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16> + %4937 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_5155 = tensor.cast %4937 : tensor<4x?xi64> to tensor<?x?xi64> + %4938 = torch_c.to_builtin_tensor %4919 : !torch.vtensor<[],si64> -> tensor<i64> + %4939 = torch_c.to_builtin_tensor %4927 : !torch.vtensor<[],si64> -> tensor<i64> + %4940 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%4936, %cast_5155, %4938, %4939) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16> + %cast_5156 = tensor.cast %4940 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16> + %4941 = torch_c.from_builtin_tensor %cast_5156 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %4941, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_5157 = torch.constant.int 2 + %int3_5158 = torch.constant.int 3 + %4942 = torch.aten.transpose.int %4935, %int2_5157, %int3_5158 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4942, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_5159 = torch.constant.int 0 + %4943 = torch.aten.clone %4942, %int0_5159 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4943, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int4_5160 = torch.constant.int 4 %int8_5161 = torch.constant.int 8 %int128_5162 = torch.constant.int 128 - %4206 = torch.prim.ListConstruct %int4_5160, %4205, %int8_5161, %int128_5162 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4207 = torch.aten._unsafe_view %4203, %4206 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4207, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_5163 = torch.constant.int 0 - %int0_5164 = torch.constant.int 0 - %int9223372036854775807_5165 = torch.constant.int 9223372036854775807 - %int1_5166 = torch.constant.int 1 - %4208 = torch.aten.slice.Tensor %4207, %int0_5163, %int0_5164, %int9223372036854775807_5165, %int1_5166 :
!torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4208, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_5167 = torch.constant.int -2 - %4209 = torch.aten.unsqueeze %4198, %int-2_5167 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4209, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_5168 = torch.constant.int 1 - %4210 = torch.aten.size.int %4197, %int1_5168 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_5169 = torch.constant.int 4 - %int8_5170 = torch.constant.int 8 - %int4_5171 = torch.constant.int 4 - %int128_5172 = torch.constant.int 128 - %4211 = torch.prim.ListConstruct %int4_5169, %4210, %int8_5170, %int4_5171, %int128_5172 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_5173 = torch.constant.bool false - %4212 = torch.aten.expand %4209, %4211, %false_5173 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4212, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_5174 = torch.constant.int 0 - %4213 = torch.aten.clone %4212, %int0_5174 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4213, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_5175 = torch.constant.int 4 - %int32_5176 = torch.constant.int 32 - %int128_5177 = torch.constant.int 128 - %4214 = torch.prim.ListConstruct %int4_5175, %4210, %int32_5176, %int128_5177 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4215 = torch.aten._unsafe_view %4213, %4214 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4215, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_5178 = torch.constant.int -2 - %4216 = torch.aten.unsqueeze %4208, %int-2_5178 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4216, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_5179 = torch.constant.int 1 - %4217 = torch.aten.size.int %4207, %int1_5179 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int + %4944 = torch.prim.ListConstruct %int4_5160, %457, %int8_5161, %int128_5162 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4945 = torch.aten._unsafe_view %4943, %4944 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4945, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_5163 = torch.constant.int 2 + %int3_5164 = torch.constant.int 3 + %4946 = torch.aten.transpose.int %4941, %int2_5163, %int3_5164 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4946, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_5165 = torch.constant.int 0 + %4947 = torch.aten.clone %4946, %int0_5165 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %4947, [%453], affine_map<()[s0] -> (4, s0, 32, 
8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_5166 = torch.constant.int 4 + %int8_5167 = torch.constant.int 8 + %int128_5168 = torch.constant.int 128 + %4948 = torch.prim.ListConstruct %int4_5166, %457, %int8_5167, %int128_5168 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4949 = torch.aten._unsafe_view %4947, %4948 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %4949, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_5169 = torch.constant.int -2 + %4950 = torch.aten.unsqueeze %4945, %int-2_5169 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %4950, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_5170 = torch.constant.int 4 + %int8_5171 = torch.constant.int 8 + %int4_5172 = torch.constant.int 4 + %int128_5173 = torch.constant.int 128 + %4951 = torch.prim.ListConstruct %int4_5170, %457, %int8_5171, %int4_5172, %int128_5173 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_5174 = torch.constant.bool false + %4952 = torch.aten.expand %4950, %4951, %false_5174 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4952, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_5175 = torch.constant.int 0 + %4953 = torch.aten.clone %4952, %int0_5175 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4953, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_5176 = torch.constant.int 4 + %int32_5177 = torch.constant.int 32 + %int128_5178 = torch.constant.int 128 + %4954 = torch.prim.ListConstruct %int4_5176, %457, %int32_5177, %int128_5178 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4955 = torch.aten._unsafe_view %4953, %4954 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4955, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_5179 = torch.constant.int -2 + %4956 = torch.aten.unsqueeze %4949, %int-2_5179 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %4956, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> %int4_5180 = torch.constant.int 4 %int8_5181 = torch.constant.int 8 %int4_5182 = torch.constant.int 4 %int128_5183 = torch.constant.int 128 - %4218 = torch.prim.ListConstruct %int4_5180, %4217, %int8_5181, %int4_5182, %int128_5183 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4957 = torch.prim.ListConstruct %int4_5180, %457, %int8_5181, %int4_5182, %int128_5183 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %false_5184 = torch.constant.bool false - %4219 = torch.aten.expand %4216, %4218, %false_5184 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4219, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %4958 = torch.aten.expand %4956, %4957, %false_5184 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + 
torch.bind_symbolic_shape %4958, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int0_5185 = torch.constant.int 0 - %4220 = torch.aten.clone %4219, %int0_5185 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4220, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %4959 = torch.aten.clone %4958, %int0_5185 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %4959, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_5186 = torch.constant.int 4 %int32_5187 = torch.constant.int 32 %int128_5188 = torch.constant.int 128 - %4221 = torch.prim.ListConstruct %int4_5186, %4217, %int32_5187, %int128_5188 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4222 = torch.aten._unsafe_view %4220, %4221 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4222, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %4960 = torch.prim.ListConstruct %int4_5186, %457, %int32_5187, %int128_5188 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %4961 = torch.aten._unsafe_view %4959, %4960 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %4961, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int1_5189 = torch.constant.int 1 %int2_5190 = torch.constant.int 2 - %4223 = torch.aten.transpose.int %4103, %int1_5189, %int2_5190 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %4962 = torch.aten.transpose.int %4844, %int1_5189, %int2_5190 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_5191 = torch.constant.int 1 %int2_5192 = torch.constant.int 2 - %4224 = torch.aten.transpose.int %4215, %int1_5191, %int2_5192 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4224, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %4963 = torch.aten.transpose.int %4955, %int1_5191, %int2_5192 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4963, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %int1_5193 = torch.constant.int 1 %int2_5194 = torch.constant.int 2 - %4225 = torch.aten.transpose.int %4222, %int1_5193, %int2_5194 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4225, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %4964 = torch.aten.transpose.int %4961, %int1_5193, %int2_5194 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %4964, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> %float0.000000e00_5195 = torch.constant.float 0.000000e+00 %false_5196 = torch.constant.bool false %none_5197 = torch.constant.none - %4226:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4223, %4224, %4225, %float0.000000e00_5195, %false_5196, %368, %none_5197) : (!torch.vtensor<[4,32,1,128],f16>, 
!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %4965:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4962, %4963, %4964, %float0.000000e00_5195, %false_5196, %470, %none_5197) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) %int1_5198 = torch.constant.int 1 %int2_5199 = torch.constant.int 2 - %4227 = torch.aten.transpose.int %4226#0, %int1_5198, %int2_5199 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %4966 = torch.aten.transpose.int %4965#0, %int1_5198, %int2_5199 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> %int4_5200 = torch.constant.int 4 %int1_5201 = torch.constant.int 1 %int4096_5202 = torch.constant.int 4096 - %4228 = torch.prim.ListConstruct %int4_5200, %int1_5201, %int4096_5202 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4229 = torch.aten.view %4227, %4228 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %4967 = torch.prim.ListConstruct %int4_5200, %int1_5201, %int4096_5202 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4968 = torch.aten.view %4966, %4967 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int-2_5203 = torch.constant.int -2 %int-1_5204 = torch.constant.int -1 - %4230 = torch.aten.transpose.int %205, %int-2_5203, %int-1_5204 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_5205 = torch.constant.int 4 - %int4096_5206 = torch.constant.int 4096 - %4231 = torch.prim.ListConstruct %int4_5205, %int4096_5206 : (!torch.int, !torch.int) -> !torch.list - %4232 = torch.aten.view %4229, %4231 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4233 = torch.aten.mm %4232, %4230 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_5207 = torch.constant.int 4 - %int1_5208 = torch.constant.int 1 - %int4096_5209 = torch.constant.int 4096 - %4234 = torch.prim.ListConstruct %int4_5207, %int1_5208, %int4096_5209 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4235 = torch.aten.view %4233, %4234 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_5210 = torch.constant.int 1 - %4236 = torch.aten.add.Tensor %4063, %4235, %int1_5210 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_5211 = torch.constant.int 6 - %4237 = torch.prims.convert_element_type %4236, %int6_5211 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_5212 = torch.constant.int 2 - %4238 = torch.aten.pow.Tensor_Scalar %4237, %int2_5212 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_5213 = torch.constant.int -1 - %4239 = torch.prim.ListConstruct %int-1_5213 : (!torch.int) -> !torch.list - %true_5214 = torch.constant.bool true - %none_5215 = torch.constant.none - %4240 = torch.aten.mean.dim %4238, %4239, %true_5214, %none_5215 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - 
%float9.999990e-06_5216 = torch.constant.float 9.9999997473787516E-6 - %int1_5217 = torch.constant.int 1 - %4241 = torch.aten.add.Scalar %4240, %float9.999990e-06_5216, %int1_5217 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %4242 = torch.aten.rsqrt %4241 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %4243 = torch.aten.mul.Tensor %4237, %4242 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_5218 = torch.constant.int 5 - %4244 = torch.prims.convert_element_type %4243, %int5_5218 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %4245 = torch.aten.mul.Tensor %206, %4244 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %4969 = torch.aten.transpose.int %291, %int-2_5203, %int-1_5204 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5205 = torch.constant.int 5 + %4970 = torch.prims.convert_element_type %4969, %int5_5205 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_5206 = torch.constant.int 4 + %int4096_5207 = torch.constant.int 4096 + %4971 = torch.prim.ListConstruct %int4_5206, %int4096_5207 : (!torch.int, !torch.int) -> !torch.list + %4972 = torch.aten.view %4968, %4971 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4973 = torch.aten.mm %4972, %4970 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_5208 = torch.constant.int 4 + %int1_5209 = torch.constant.int 1 + %int4096_5210 = torch.constant.int 4096 + %4974 = torch.prim.ListConstruct %int4_5208, %int1_5209, %int4096_5210 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4975 = torch.aten.view %4973, %4974 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_5211 = torch.constant.int 1 + %4976 = torch.aten.add.Tensor %4797, %4975, %int1_5211 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_5212 = torch.constant.int 6 + %4977 = torch.prims.convert_element_type %4976, %int6_5212 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_5213 = torch.constant.int 2 + %4978 = torch.aten.pow.Tensor_Scalar %4977, %int2_5213 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_5214 = torch.constant.int -1 + %4979 = torch.prim.ListConstruct %int-1_5214 : (!torch.int) -> !torch.list + %true_5215 = torch.constant.bool true + %none_5216 = torch.constant.none + %4980 = torch.aten.mean.dim %4978, %4979, %true_5215, %none_5216 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_5217 = torch.constant.float 9.9999997473787516E-6 + %int1_5218 = torch.constant.int 1 + %4981 = torch.aten.add.Scalar %4980, %float9.999990e-06_5217, %int1_5218 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %4982 = torch.aten.rsqrt %4981 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %4983 = torch.aten.mul.Tensor %4977, %4982 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> %int5_5219 = torch.constant.int 5 - %4246 = torch.prims.convert_element_type %4245, %int5_5219 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_5220 = 
torch.constant.int -2 - %int-1_5221 = torch.constant.int -1 - %4247 = torch.aten.transpose.int %207, %int-2_5220, %int-1_5221 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_5222 = torch.constant.int 4 - %int4096_5223 = torch.constant.int 4096 - %4248 = torch.prim.ListConstruct %int4_5222, %int4096_5223 : (!torch.int, !torch.int) -> !torch.list - %4249 = torch.aten.view %4246, %4248 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4250 = torch.aten.mm %4249, %4247 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %4984 = torch.prims.convert_element_type %4983, %int5_5219 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %4985 = torch.aten.mul.Tensor %292, %4984 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_5220 = torch.constant.int 5 + %4986 = torch.prims.convert_element_type %4985, %int5_5220 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_5221 = torch.constant.int -2 + %int-1_5222 = torch.constant.int -1 + %4987 = torch.aten.transpose.int %293, %int-2_5221, %int-1_5222 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5223 = torch.constant.int 5 + %4988 = torch.prims.convert_element_type %4987, %int5_5223 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> %int4_5224 = torch.constant.int 4 - %int1_5225 = torch.constant.int 1 - %int14336_5226 = torch.constant.int 14336 - %4251 = torch.prim.ListConstruct %int4_5224, %int1_5225, %int14336_5226 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4252 = torch.aten.view %4250, %4251 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %4253 = torch.aten.silu %4252 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_5227 = torch.constant.int -2 - %int-1_5228 = torch.constant.int -1 - %4254 = torch.aten.transpose.int %208, %int-2_5227, %int-1_5228 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_5229 = torch.constant.int 4 - %int4096_5230 = torch.constant.int 4096 - %4255 = torch.prim.ListConstruct %int4_5229, %int4096_5230 : (!torch.int, !torch.int) -> !torch.list - %4256 = torch.aten.view %4246, %4255 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4257 = torch.aten.mm %4256, %4254 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_5231 = torch.constant.int 4 - %int1_5232 = torch.constant.int 1 - %int14336_5233 = torch.constant.int 14336 - %4258 = torch.prim.ListConstruct %int4_5231, %int1_5232, %int14336_5233 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4259 = torch.aten.view %4257, %4258 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %4260 = torch.aten.mul.Tensor %4253, %4259 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_5234 = torch.constant.int -2 - %int-1_5235 = torch.constant.int -1 - %4261 = torch.aten.transpose.int %209, %int-2_5234, %int-1_5235 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_5236 = torch.constant.int 4 - %int14336_5237 = torch.constant.int 14336 - %4262 = torch.prim.ListConstruct %int4_5236, %int14336_5237 : 
(!torch.int, !torch.int) -> !torch.list - %4263 = torch.aten.view %4260, %4262 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %4264 = torch.aten.mm %4263, %4261 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_5238 = torch.constant.int 4 - %int1_5239 = torch.constant.int 1 - %int4096_5240 = torch.constant.int 4096 - %4265 = torch.prim.ListConstruct %int4_5238, %int1_5239, %int4096_5240 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4266 = torch.aten.view %4264, %4265 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_5241 = torch.constant.int 1 - %4267 = torch.aten.add.Tensor %4236, %4266, %int1_5241 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_5242 = torch.constant.int 6 - %4268 = torch.prims.convert_element_type %4267, %int6_5242 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_5243 = torch.constant.int 2 - %4269 = torch.aten.pow.Tensor_Scalar %4268, %int2_5243 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_5244 = torch.constant.int -1 - %4270 = torch.prim.ListConstruct %int-1_5244 : (!torch.int) -> !torch.list - %true_5245 = torch.constant.bool true - %none_5246 = torch.constant.none - %4271 = torch.aten.mean.dim %4269, %4270, %true_5245, %none_5246 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_5247 = torch.constant.float 9.9999997473787516E-6 - %int1_5248 = torch.constant.int 1 - %4272 = torch.aten.add.Scalar %4271, %float9.999990e-06_5247, %int1_5248 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %4273 = torch.aten.rsqrt %4272 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %4274 = torch.aten.mul.Tensor %4268, %4273 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_5249 = torch.constant.int 5 - %4275 = torch.prims.convert_element_type %4274, %int5_5249 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %4276 = torch.aten.mul.Tensor %210, %4275 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_5250 = torch.constant.int 5 - %4277 = torch.prims.convert_element_type %4276, %int5_5250 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_5251 = torch.constant.int -2 - %int-1_5252 = torch.constant.int -1 - %4278 = torch.aten.transpose.int %211, %int-2_5251, %int-1_5252 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_5253 = torch.constant.int 4 - %int4096_5254 = torch.constant.int 4096 - %4279 = torch.prim.ListConstruct %int4_5253, %int4096_5254 : (!torch.int, !torch.int) -> !torch.list - %4280 = torch.aten.view %4277, %4279 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4281 = torch.aten.mm %4280, %4278 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_5255 = torch.constant.int 4 - %int1_5256 = torch.constant.int 1 - %int4096_5257 = torch.constant.int 4096 - %4282 = torch.prim.ListConstruct %int4_5255, %int1_5256, %int4096_5257 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4283 = torch.aten.view %4281, %4282 : !torch.vtensor<[4,4096],f16>, !torch.list -> 
!torch.vtensor<[4,1,4096],f16> - %int-2_5258 = torch.constant.int -2 - %int-1_5259 = torch.constant.int -1 - %4284 = torch.aten.transpose.int %212, %int-2_5258, %int-1_5259 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4096_5225 = torch.constant.int 4096 + %4989 = torch.prim.ListConstruct %int4_5224, %int4096_5225 : (!torch.int, !torch.int) -> !torch.list + %4990 = torch.aten.view %4986, %4989 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4991 = torch.aten.mm %4990, %4988 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_5226 = torch.constant.int 4 + %int1_5227 = torch.constant.int 1 + %int14336_5228 = torch.constant.int 14336 + %4992 = torch.prim.ListConstruct %int4_5226, %int1_5227, %int14336_5228 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %4993 = torch.aten.view %4991, %4992 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %4994 = torch.aten.silu %4993 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_5229 = torch.constant.int -2 + %int-1_5230 = torch.constant.int -1 + %4995 = torch.aten.transpose.int %294, %int-2_5229, %int-1_5230 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5231 = torch.constant.int 5 + %4996 = torch.prims.convert_element_type %4995, %int5_5231 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_5232 = torch.constant.int 4 + %int4096_5233 = torch.constant.int 4096 + %4997 = torch.prim.ListConstruct %int4_5232, %int4096_5233 : (!torch.int, !torch.int) -> !torch.list + %4998 = torch.aten.view %4986, %4997 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %4999 = torch.aten.mm %4998, %4996 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_5234 = torch.constant.int 4 + %int1_5235 = torch.constant.int 1 + %int14336_5236 = torch.constant.int 14336 + %5000 = torch.prim.ListConstruct %int4_5234, %int1_5235, %int14336_5236 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5001 = torch.aten.view %4999, %5000 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %5002 = torch.aten.mul.Tensor %4994, %5001 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_5237 = torch.constant.int -2 + %int-1_5238 = torch.constant.int -1 + %5003 = torch.aten.transpose.int %295, %int-2_5237, %int-1_5238 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_5239 = torch.constant.int 5 + %5004 = torch.prims.convert_element_type %5003, %int5_5239 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_5240 = torch.constant.int 4 + %int14336_5241 = torch.constant.int 14336 + %5005 = torch.prim.ListConstruct %int4_5240, %int14336_5241 : (!torch.int, !torch.int) -> !torch.list + %5006 = torch.aten.view %5002, %5005 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %5007 = torch.aten.mm %5006, %5004 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_5242 = torch.constant.int 4 + %int1_5243 = torch.constant.int 1 + %int4096_5244 = torch.constant.int 4096 + %5008 = torch.prim.ListConstruct %int4_5242, %int1_5243, %int4096_5244 : (!torch.int, 
!torch.int, !torch.int) -> !torch.list + %5009 = torch.aten.view %5007, %5008 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_5245 = torch.constant.int 1 + %5010 = torch.aten.add.Tensor %4976, %5009, %int1_5245 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_5246 = torch.constant.int 6 + %5011 = torch.prims.convert_element_type %5010, %int6_5246 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_5247 = torch.constant.int 2 + %5012 = torch.aten.pow.Tensor_Scalar %5011, %int2_5247 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_5248 = torch.constant.int -1 + %5013 = torch.prim.ListConstruct %int-1_5248 : (!torch.int) -> !torch.list + %true_5249 = torch.constant.bool true + %none_5250 = torch.constant.none + %5014 = torch.aten.mean.dim %5012, %5013, %true_5249, %none_5250 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_5251 = torch.constant.float 9.9999997473787516E-6 + %int1_5252 = torch.constant.int 1 + %5015 = torch.aten.add.Scalar %5014, %float9.999990e-06_5251, %int1_5252 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %5016 = torch.aten.rsqrt %5015 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %5017 = torch.aten.mul.Tensor %5011, %5016 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_5253 = torch.constant.int 5 + %5018 = torch.prims.convert_element_type %5017, %int5_5253 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %5019 = torch.aten.mul.Tensor %296, %5018 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_5254 = torch.constant.int 5 + %5020 = torch.prims.convert_element_type %5019, %int5_5254 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_5255 = torch.constant.int -2 + %int-1_5256 = torch.constant.int -1 + %5021 = torch.aten.transpose.int %297, %int-2_5255, %int-1_5256 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5257 = torch.constant.int 5 + %5022 = torch.prims.convert_element_type %5021, %int5_5257 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_5258 = torch.constant.int 4 + %int4096_5259 = torch.constant.int 4096 + %5023 = torch.prim.ListConstruct %int4_5258, %int4096_5259 : (!torch.int, !torch.int) -> !torch.list + %5024 = torch.aten.view %5020, %5023 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5025 = torch.aten.mm %5024, %5022 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_5260 = torch.constant.int 4 - %int4096_5261 = torch.constant.int 4096 - %4285 = torch.prim.ListConstruct %int4_5260, %int4096_5261 : (!torch.int, !torch.int) -> !torch.list - %4286 = torch.aten.view %4277, %4285 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4287 = torch.aten.mm %4286, %4284 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_5262 = torch.constant.int 4 - %int1_5263 = torch.constant.int 1 - %int1024_5264 = torch.constant.int 1024 - %4288 = torch.prim.ListConstruct %int4_5262, %int1_5263, %int1024_5264 : (!torch.int, !torch.int, 
!torch.int) -> !torch.list - %4289 = torch.aten.view %4287, %4288 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_5265 = torch.constant.int -2 - %int-1_5266 = torch.constant.int -1 - %4290 = torch.aten.transpose.int %213, %int-2_5265, %int-1_5266 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_5267 = torch.constant.int 4 - %int4096_5268 = torch.constant.int 4096 - %4291 = torch.prim.ListConstruct %int4_5267, %int4096_5268 : (!torch.int, !torch.int) -> !torch.list - %4292 = torch.aten.view %4277, %4291 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4293 = torch.aten.mm %4292, %4290 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_5269 = torch.constant.int 4 - %int1_5270 = torch.constant.int 1 - %int1024_5271 = torch.constant.int 1024 - %4294 = torch.prim.ListConstruct %int4_5269, %int1_5270, %int1024_5271 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4295 = torch.aten.view %4293, %4294 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_5272 = torch.constant.int 4 - %int1_5273 = torch.constant.int 1 - %int32_5274 = torch.constant.int 32 - %int128_5275 = torch.constant.int 128 - %4296 = torch.prim.ListConstruct %int4_5272, %int1_5273, %int32_5274, %int128_5275 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4297 = torch.aten.view %4283, %4296 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int1_5261 = torch.constant.int 1 + %int4096_5262 = torch.constant.int 4096 + %5026 = torch.prim.ListConstruct %int4_5260, %int1_5261, %int4096_5262 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5027 = torch.aten.view %5025, %5026 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_5263 = torch.constant.int -2 + %int-1_5264 = torch.constant.int -1 + %5028 = torch.aten.transpose.int %298, %int-2_5263, %int-1_5264 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5265 = torch.constant.int 5 + %5029 = torch.prims.convert_element_type %5028, %int5_5265 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_5266 = torch.constant.int 4 + %int4096_5267 = torch.constant.int 4096 + %5030 = torch.prim.ListConstruct %int4_5266, %int4096_5267 : (!torch.int, !torch.int) -> !torch.list + %5031 = torch.aten.view %5020, %5030 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5032 = torch.aten.mm %5031, %5029 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_5268 = torch.constant.int 4 + %int1_5269 = torch.constant.int 1 + %int1024_5270 = torch.constant.int 1024 + %5033 = torch.prim.ListConstruct %int4_5268, %int1_5269, %int1024_5270 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5034 = torch.aten.view %5032, %5033 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_5271 = torch.constant.int -2 + %int-1_5272 = torch.constant.int -1 + %5035 = torch.aten.transpose.int %299, %int-2_5271, %int-1_5272 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5273 = torch.constant.int 5 + %5036 = torch.prims.convert_element_type %5035, %int5_5273 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_5274 = 
torch.constant.int 4 + %int4096_5275 = torch.constant.int 4096 + %5037 = torch.prim.ListConstruct %int4_5274, %int4096_5275 : (!torch.int, !torch.int) -> !torch.list + %5038 = torch.aten.view %5020, %5037 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5039 = torch.aten.mm %5038, %5036 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> %int4_5276 = torch.constant.int 4 %int1_5277 = torch.constant.int 1 - %int8_5278 = torch.constant.int 8 - %int128_5279 = torch.constant.int 128 - %4298 = torch.prim.ListConstruct %int4_5276, %int1_5277, %int8_5278, %int128_5279 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4299 = torch.aten.view %4289, %4298 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_5280 = torch.constant.int 4 - %int1_5281 = torch.constant.int 1 - %int8_5282 = torch.constant.int 8 - %int128_5283 = torch.constant.int 128 - %4300 = torch.prim.ListConstruct %int4_5280, %int1_5281, %int8_5282, %int128_5283 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4301 = torch.aten.view %4295, %4300 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_5284 = torch.constant.int 6 - %4302 = torch.prims.convert_element_type %4297, %int6_5284 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %4303 = torch_c.to_builtin_tensor %4302 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %4304 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %4305 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%4303, %4304) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %4306 = torch_c.from_builtin_tensor %4305 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_5285 = torch.constant.int 5 - %4307 = torch.prims.convert_element_type %4306, %int5_5285 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_5286 = torch.constant.int 6 - %4308 = torch.prims.convert_element_type %4299, %int6_5286 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %4309 = torch_c.to_builtin_tensor %4308 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %4310 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %4311 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%4309, %4310) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %4312 = torch_c.from_builtin_tensor %4311 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_5287 = torch.constant.int 5 - %4313 = torch.prims.convert_element_type %4312, %int5_5287 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_5288 = torch.constant.int 32 - %4314 = torch.aten.floor_divide.Scalar %arg2, %int32_5288 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_5289 = torch.constant.int 1 - %4315 = torch.aten.unsqueeze %4314, %int1_5289 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5290 = torch.constant.int 1 - %false_5291 = torch.constant.bool false - %4316 = torch.aten.gather %arg3, %int1_5290, %4315, %false_5291 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_5292 = torch.constant.int 32 - %4317 = torch.aten.remainder.Scalar %arg2, %int32_5292 : 
!torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_5293 = torch.constant.int 1 - %4318 = torch.aten.unsqueeze %4317, %int1_5293 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_5294 = torch.constant.none - %4319 = torch.aten.clone %214, %none_5294 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_5295 = torch.constant.int 0 - %4320 = torch.aten.unsqueeze %4319, %int0_5295 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_5296 = torch.constant.int 4 - %int1_5297 = torch.constant.int 1 - %4321 = torch.prim.ListConstruct %int4_5296, %int1_5297 : (!torch.int, !torch.int) -> !torch.list - %int1_5298 = torch.constant.int 1 - %int1_5299 = torch.constant.int 1 - %4322 = torch.prim.ListConstruct %int1_5298, %int1_5299 : (!torch.int, !torch.int) -> !torch.list - %int4_5300 = torch.constant.int 4 - %int0_5301 = torch.constant.int 0 - %cpu_5302 = torch.constant.device "cpu" - %false_5303 = torch.constant.bool false - %4323 = torch.aten.empty_strided %4321, %4322, %int4_5300, %int0_5301, %cpu_5302, %false_5303 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int19 = torch.constant.int 19 - %4324 = torch.aten.fill.Scalar %4323, %int19 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_5304 = torch.constant.int 4 + %int1024_5278 = torch.constant.int 1024 + %5040 = torch.prim.ListConstruct %int4_5276, %int1_5277, %int1024_5278 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5041 = torch.aten.view %5039, %5040 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_5279 = torch.constant.int 4 + %int1_5280 = torch.constant.int 1 + %int32_5281 = torch.constant.int 32 + %int128_5282 = torch.constant.int 128 + %5042 = torch.prim.ListConstruct %int4_5279, %int1_5280, %int32_5281, %int128_5282 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5043 = torch.aten.view %5027, %5042 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_5283 = torch.constant.int 4 + %int1_5284 = torch.constant.int 1 + %int8_5285 = torch.constant.int 8 + %int128_5286 = torch.constant.int 128 + %5044 = torch.prim.ListConstruct %int4_5283, %int1_5284, %int8_5285, %int128_5286 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5045 = torch.aten.view %5034, %5044 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_5287 = torch.constant.int 4 + %int1_5288 = torch.constant.int 1 + %int8_5289 = torch.constant.int 8 + %int128_5290 = torch.constant.int 128 + %5046 = torch.prim.ListConstruct %int4_5287, %int1_5288, %int8_5289, %int128_5290 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5047 = torch.aten.view %5041, %5046 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_5291 = torch.constant.int 1 + %int2_5292 = torch.constant.int 2 + %5048 = torch.aten.transpose.int %5043, %int1_5291, %int2_5292 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %5049 = torch.aten.mul.Tensor %5048, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_5293 = torch.constant.int 3 + %int0_5294 = torch.constant.int 0 + %int64_5295 = torch.constant.int 64 + %int1_5296 = torch.constant.int 1 + %5050 = torch.aten.slice.Tensor %5048, %int3_5293, %int0_5294, %int64_5295, 
%int1_5296 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_5297 = torch.constant.int 3 + %int64_5298 = torch.constant.int 64 + %int9223372036854775807_5299 = torch.constant.int 9223372036854775807 + %int1_5300 = torch.constant.int 1 + %5051 = torch.aten.slice.Tensor %5048, %int3_5297, %int64_5298, %int9223372036854775807_5299, %int1_5300 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %5052 = torch.aten.neg %5051 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %5053 = torch.prim.ListConstruct %5052, %5050 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_5301 = torch.constant.int -1 + %5054 = torch.aten.cat %5053, %int-1_5301 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %5055 = torch.aten.mul.Tensor %5054, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_5302 = torch.constant.int 1 + %5056 = torch.aten.add.Tensor %5049, %5055, %int1_5302 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_5303 = torch.constant.int 1 + %int2_5304 = torch.constant.int 2 + %5057 = torch.aten.transpose.int %5056, %int1_5303, %int2_5304 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> %int1_5305 = torch.constant.int 1 - %4325 = torch.prim.ListConstruct %int4_5304, %int1_5305 : (!torch.int, !torch.int) -> !torch.list - %4326 = torch.aten.repeat %4320, %4325 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_5306 = torch.constant.int 32 - %4327 = torch.aten.mul.Scalar %4316, %int32_5306 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5307 = torch.constant.int 1 - %4328 = torch.aten.add.Tensor %4327, %4324, %int1_5307 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_5308 = torch.constant.int 2 - %4329 = torch.aten.mul.Scalar %4328, %int2_5308 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5309 = torch.constant.int 1 - %4330 = torch.aten.add.Tensor %4329, %4326, %int1_5309 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_5310 = torch.constant.int 32 - %4331 = torch.aten.mul.Scalar %4330, %int32_5310 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5311 = torch.constant.int 1 - %4332 = torch.aten.add.Tensor %4331, %4318, %int1_5311 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_5312 = torch.constant.int 32 - %int2_5313 = torch.constant.int 2 - %int32_5314 = torch.constant.int 32 - %int8_5315 = torch.constant.int 8 - %int128_5316 = torch.constant.int 128 - %4333 = torch.prim.ListConstruct %437, %int32_5312, %int2_5313, %int32_5314, %int8_5315, %int128_5316 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4334 = torch.aten.view %4170, %4333 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4334, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5317 = torch.constant.int 32 - %4335 = torch.aten.mul.int %437, %int32_5317 : !torch.int, !torch.int -> 
!torch.int + %int2_5306 = torch.constant.int 2 + %5058 = torch.aten.transpose.int %5045, %int1_5305, %int2_5306 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %5059 = torch.aten.mul.Tensor %5058, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_5307 = torch.constant.int 3 + %int0_5308 = torch.constant.int 0 + %int64_5309 = torch.constant.int 64 + %int1_5310 = torch.constant.int 1 + %5060 = torch.aten.slice.Tensor %5058, %int3_5307, %int0_5308, %int64_5309, %int1_5310 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_5311 = torch.constant.int 3 + %int64_5312 = torch.constant.int 64 + %int9223372036854775807_5313 = torch.constant.int 9223372036854775807 + %int1_5314 = torch.constant.int 1 + %5061 = torch.aten.slice.Tensor %5058, %int3_5311, %int64_5312, %int9223372036854775807_5313, %int1_5314 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %5062 = torch.aten.neg %5061 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %5063 = torch.prim.ListConstruct %5062, %5060 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_5315 = torch.constant.int -1 + %5064 = torch.aten.cat %5063, %int-1_5315 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %5065 = torch.aten.mul.Tensor %5064, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_5316 = torch.constant.int 1 + %5066 = torch.aten.add.Tensor %5059, %5065, %int1_5316 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_5317 = torch.constant.int 1 %int2_5318 = torch.constant.int 2 - %4336 = torch.aten.mul.int %4335, %int2_5318 : !torch.int, !torch.int -> !torch.int + %5067 = torch.aten.transpose.int %5066, %int1_5317, %int2_5318 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> %int32_5319 = torch.constant.int 32 - %4337 = torch.aten.mul.int %4336, %int32_5319 : !torch.int, !torch.int -> !torch.int - %int8_5320 = torch.constant.int 8 - %int128_5321 = torch.constant.int 128 - %4338 = torch.prim.ListConstruct %4337, %int8_5320, %int128_5321 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4339 = torch.aten.view %4334, %4338 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4339, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %4340 = torch.prim.ListConstruct %4332 : (!torch.vtensor<[4,1],si64>) -> !torch.list> + %5068 = torch.aten.floor_divide.Scalar %arg2, %int32_5319 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_5320 = torch.constant.int 1 + %5069 = torch.aten.unsqueeze %5068, %int1_5320 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_5321 = torch.constant.int 1 %false_5322 = torch.constant.bool false - %4341 = torch.aten.index_put %4339, %4340, %4313, %false_5322 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4341, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_5323 = torch.constant.int 32 - %int2_5324 = torch.constant.int 2 - %int32_5325 = 
torch.constant.int 32 - %int8_5326 = torch.constant.int 8 - %int128_5327 = torch.constant.int 128 - %4342 = torch.prim.ListConstruct %437, %int32_5323, %int2_5324, %int32_5325, %int8_5326, %int128_5327 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4343 = torch.aten.view %4341, %4342 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4343, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5328 = torch.constant.int 2097152 - %4344 = torch.prim.ListConstruct %437, %int2097152_5328 : (!torch.int, !torch.int) -> !torch.list - %4345 = torch.aten.view %4343, %4344 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4345, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_5329 = torch.constant.int 32 - %int2_5330 = torch.constant.int 2 - %int32_5331 = torch.constant.int 32 - %int8_5332 = torch.constant.int 8 - %int128_5333 = torch.constant.int 128 - %4346 = torch.prim.ListConstruct %437, %int32_5329, %int2_5330, %int32_5331, %int8_5332, %int128_5333 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4347 = torch.aten.view %4345, %4346 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4347, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_5334 = torch.constant.int 8 - %int128_5335 = torch.constant.int 128 - %4348 = torch.prim.ListConstruct %4337, %int8_5334, %int128_5335 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4349 = torch.aten.view %4347, %4348 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4349, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_5336 = torch.constant.int 32 - %4350 = torch.aten.floor_divide.Scalar %arg2, %int32_5336 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_5337 = torch.constant.int 1 - %4351 = torch.aten.unsqueeze %4350, %int1_5337 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5338 = torch.constant.int 1 - %false_5339 = torch.constant.bool false - %4352 = torch.aten.gather %arg3, %int1_5338, %4351, %false_5339 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_5340 = torch.constant.int 32 - %4353 = torch.aten.remainder.Scalar %arg2, %int32_5340 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %5070 = torch.aten.gather %arg3, %int1_5321, %5069, %false_5322 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_5323 = torch.constant.int 4 + %int1_5324 = torch.constant.int 1 + %int1_5325 = torch.constant.int 1 + %5071 = torch.prim.ListConstruct %int4_5323, %int1_5324, %int1_5325 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5072 = torch.aten.view %5070, %5071 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_5326 = torch.constant.int 32 + %5073 = torch.aten.remainder.Scalar %arg2, %int32_5326 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_5327 = torch.constant.int 4 + %int1_5328 = torch.constant.int 1 + %int1_5329 = torch.constant.int 1 + %5074 = 
torch.prim.ListConstruct %int4_5327, %int1_5328, %int1_5329 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5075 = torch.aten.view %5073, %5074 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_5330 = torch.constant.int 8 + %none_5331 = torch.constant.none + %none_5332 = torch.constant.none + %cpu_5333 = torch.constant.device "cpu" + %false_5334 = torch.constant.bool false + %5076 = torch.aten.arange %int8_5330, %none_5331, %none_5332, %cpu_5333, %false_5334 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_5335 = torch.constant.int 1 + %int1_5336 = torch.constant.int 1 + %int8_5337 = torch.constant.int 8 + %5077 = torch.prim.ListConstruct %int1_5335, %int1_5336, %int8_5337 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5078 = torch.aten.view %5076, %5077 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_5338 = torch.constant.none + %5079 = torch.aten.clone %300, %none_5338 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5080 = torch.aten.detach %5079 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5081 = torch.aten.detach %5080 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5082 = torch.aten.detach %5081 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_5339 = torch.constant.int 1 + %int1_5340 = torch.constant.int 1 %int1_5341 = torch.constant.int 1 - %4354 = torch.aten.unsqueeze %4353, %int1_5341 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_5342 = torch.constant.none - %4355 = torch.aten.clone %215, %none_5342 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_5343 = torch.constant.int 0 - %4356 = torch.aten.unsqueeze %4355, %int0_5343 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_5344 = torch.constant.int 4 + %5083 = torch.prim.ListConstruct %int1_5339, %int1_5340, %int1_5341 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5084 = torch.aten.view %5082, %5083 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_5342 = torch.constant.int 32 + %5085 = torch.aten.mul.Scalar %5072, %int32_5342 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int21 = torch.constant.int 21 + %int1_5343 = torch.constant.int 1 + %5086 = torch.aten.add.Scalar %5085, %int21, %int1_5343 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_5344 = torch.constant.int 2 + %5087 = torch.aten.mul.Scalar %5086, %int2_5344 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_5345 = torch.constant.int 1 - %4357 = torch.prim.ListConstruct %int4_5344, %int1_5345 : (!torch.int, !torch.int) -> !torch.list - %int1_5346 = torch.constant.int 1 + %5088 = torch.aten.add.Tensor %5087, %5084, %int1_5345 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_5346 = torch.constant.int 8 + %5089 = torch.aten.mul.Scalar %5088, %int8_5346 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_5347 = torch.constant.int 1 - %4358 = torch.prim.ListConstruct %int1_5346, %int1_5347 : (!torch.int, !torch.int) -> !torch.list - %int4_5348 = torch.constant.int 4 - %int0_5349 = torch.constant.int 0 - %cpu_5350 = torch.constant.device "cpu" - %false_5351 = torch.constant.bool false - %4359 = torch.aten.empty_strided %4357, %4358, %int4_5348, %int0_5349, %cpu_5350, %false_5351 
: !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int19_5352 = torch.constant.int 19 - %4360 = torch.aten.fill.Scalar %4359, %int19_5352 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_5353 = torch.constant.int 4 - %int1_5354 = torch.constant.int 1 - %4361 = torch.prim.ListConstruct %int4_5353, %int1_5354 : (!torch.int, !torch.int) -> !torch.list - %4362 = torch.aten.repeat %4356, %4361 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_5355 = torch.constant.int 32 - %4363 = torch.aten.mul.Scalar %4352, %int32_5355 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5356 = torch.constant.int 1 - %4364 = torch.aten.add.Tensor %4363, %4360, %int1_5356 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_5357 = torch.constant.int 2 - %4365 = torch.aten.mul.Scalar %4364, %int2_5357 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5358 = torch.constant.int 1 - %4366 = torch.aten.add.Tensor %4365, %4362, %int1_5358 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_5359 = torch.constant.int 32 - %4367 = torch.aten.mul.Scalar %4366, %int32_5359 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5360 = torch.constant.int 1 - %4368 = torch.aten.add.Tensor %4367, %4354, %int1_5360 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %4369 = torch.prim.ListConstruct %4368 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_5361 = torch.constant.bool false - %4370 = torch.aten.index_put %4349, %4369, %4301, %false_5361 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4370, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_5362 = torch.constant.int 32 - %int2_5363 = torch.constant.int 2 + %5090 = torch.aten.add.Tensor %5089, %5078, %int1_5347 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_5348 = torch.constant.int 32 + %5091 = torch.aten.mul.Scalar %5090, %int32_5348 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_5349 = torch.constant.int 1 + %5092 = torch.aten.add.Tensor %5091, %5075, %int1_5349 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_5350 = torch.constant.int 5 + %5093 = torch.prims.convert_element_type %5067, %int5_5350 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_5351 = torch.constant.int 32 + %int2_5352 = torch.constant.int 2 + %int8_5353 = torch.constant.int 8 + %int32_5354 = torch.constant.int 32 + %int128_5355 = torch.constant.int 128 + %5094 = torch.prim.ListConstruct %456, %int32_5351, %int2_5352, %int8_5353, %int32_5354, %int128_5355 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5095 = torch.aten.view %4915, %5094 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5095, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_5356 = torch.constant.int 128 + %5096 = torch.prim.ListConstruct %596, %int128_5356 : 
(!torch.int, !torch.int) -> !torch.list + %5097 = torch.aten.view %5095, %5096 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5097, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %5098 = torch.prim.ListConstruct %5092 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_5357 = torch.constant.bool false + %5099 = torch.aten.index_put %5097, %5098, %5093, %false_5357 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5099, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_5358 = torch.constant.int 32 + %int2_5359 = torch.constant.int 2 + %int8_5360 = torch.constant.int 8 + %int32_5361 = torch.constant.int 32 + %int128_5362 = torch.constant.int 128 + %5100 = torch.prim.ListConstruct %456, %int32_5358, %int2_5359, %int8_5360, %int32_5361, %int128_5362 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5101 = torch.aten.view %5099, %5100 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5101, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_5363 = torch.constant.int 2097152 + %5102 = torch.prim.ListConstruct %456, %int2097152_5363 : (!torch.int, !torch.int) -> !torch.list + %5103 = torch.aten.view %5101, %5102 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5103, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> %int32_5364 = torch.constant.int 32 - %int8_5365 = torch.constant.int 8 - %int128_5366 = torch.constant.int 128 - %4371 = torch.prim.ListConstruct %437, %int32_5362, %int2_5363, %int32_5364, %int8_5365, %int128_5366 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4372 = torch.aten.view %4370, %4371 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4372, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5367 = torch.constant.int 2097152 - %4373 = torch.prim.ListConstruct %437, %int2097152_5367 : (!torch.int, !torch.int) -> !torch.list - %4374 = torch.aten.view %4372, %4373 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4374, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_5368 = torch.constant.int 4 - %4375 = torch.prim.ListConstruct %int4_5368, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_5369 = torch.constant.int 1 - %4376 = torch.prim.ListConstruct %358, %int1_5369 : (!torch.int, !torch.int) -> !torch.list - %int4_5370 = torch.constant.int 4 - %int0_5371 = torch.constant.int 0 - %cpu_5372 = torch.constant.device "cpu" - %false_5373 = torch.constant.bool false - %4377 = torch.aten.empty_strided %4375, %4376, %int4_5370, %int0_5371, %cpu_5372, %false_5373 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4377, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int19_5374 = torch.constant.int 19 - %4378 = torch.aten.fill.Scalar %4377, %int19_5374 : !torch.vtensor<[4,?],si64>, !torch.int -> 
!torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4378, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_5375 = torch.constant.int 32 - %4379 = torch.aten.mul.Scalar %arg3, %int32_5375 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4379, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int2_5365 = torch.constant.int 2 + %int8_5366 = torch.constant.int 8 + %int32_5367 = torch.constant.int 32 + %int128_5368 = torch.constant.int 128 + %5104 = torch.prim.ListConstruct %456, %int32_5364, %int2_5365, %int8_5366, %int32_5367, %int128_5368 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5105 = torch.aten.view %5103, %5104 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5105, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_5369 = torch.constant.int 128 + %5106 = torch.prim.ListConstruct %596, %int128_5369 : (!torch.int, !torch.int) -> !torch.list + %5107 = torch.aten.view %5105, %5106 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5107, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_5370 = torch.constant.none + %5108 = torch.aten.clone %301, %none_5370 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5109 = torch.aten.detach %5108 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5110 = torch.aten.detach %5109 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5111 = torch.aten.detach %5110 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_5371 = torch.constant.int 1 + %int1_5372 = torch.constant.int 1 + %int1_5373 = torch.constant.int 1 + %5112 = torch.prim.ListConstruct %int1_5371, %int1_5372, %int1_5373 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5113 = torch.aten.view %5111, %5112 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_5374 = torch.constant.int 32 + %5114 = torch.aten.mul.Scalar %5072, %int32_5374 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int21_5375 = torch.constant.int 21 %int1_5376 = torch.constant.int 1 - %4380 = torch.aten.add.Tensor %4379, %4378, %int1_5376 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4380, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_5377 = torch.constant.int 4 - %4381 = torch.aten.mul.int %int4_5377, %358 : !torch.int, !torch.int -> !torch.int - %4382 = torch.prim.ListConstruct %4381 : (!torch.int) -> !torch.list - %4383 = torch.aten.view %4380, %4382 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %4383, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_5378 = torch.constant.int 32 - %int2_5379 = torch.constant.int 2 - %int32_5380 = torch.constant.int 32 - %int8_5381 = torch.constant.int 8 - %int128_5382 = torch.constant.int 128 - %4384 = torch.prim.ListConstruct %437, %int32_5378, %int2_5379, %int32_5380, %int8_5381, %int128_5382 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4385 = torch.aten.view %4374, %4384 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4385, 
[%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5383 = torch.constant.int 32 - %4386 = torch.aten.mul.int %437, %int32_5383 : !torch.int, !torch.int -> !torch.int - %int2_5384 = torch.constant.int 2 + %5115 = torch.aten.add.Scalar %5114, %int21_5375, %int1_5376 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_5377 = torch.constant.int 2 + %5116 = torch.aten.mul.Scalar %5115, %int2_5377 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_5378 = torch.constant.int 1 + %5117 = torch.aten.add.Tensor %5116, %5113, %int1_5378 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_5379 = torch.constant.int 8 + %5118 = torch.aten.mul.Scalar %5117, %int8_5379 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_5380 = torch.constant.int 1 + %5119 = torch.aten.add.Tensor %5118, %5078, %int1_5380 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_5381 = torch.constant.int 32 + %5120 = torch.aten.mul.Scalar %5119, %int32_5381 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_5382 = torch.constant.int 1 + %5121 = torch.aten.add.Tensor %5120, %5075, %int1_5382 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_5383 = torch.constant.int 5 + %5122 = torch.prims.convert_element_type %5047, %int5_5383 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %5123 = torch.prim.ListConstruct %5121 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>> + %false_5384 = torch.constant.bool false + %5124 = torch.aten.index_put %5107, %5123, %5122, %false_5384 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5124, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> %int32_5385 = torch.constant.int 32 - %int8_5386 = torch.constant.int 8 - %int128_5387 = torch.constant.int 128 - %4387 = torch.prim.ListConstruct %4386, %int2_5384, %int32_5385, %int8_5386, %int128_5387 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %4388 = torch.aten.view %4385, %4387 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %4388, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_5388 = torch.constant.int 0 - %4389 = torch.aten.index_select %4388, %int0_5388, %4383 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %4389, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_5389 = torch.constant.int 4 - %int2_5390 = torch.constant.int 2 - %int32_5391 = torch.constant.int 32 - %int8_5392 = torch.constant.int 8 - %int128_5393 = torch.constant.int 128 - %4390 = torch.prim.ListConstruct %int4_5389, %358, %int2_5390, %int32_5391, %int8_5392, %int128_5393 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %4391 = torch.aten.view %4389, %4390 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4391, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 
128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_5394 = torch.constant.int 0 - %int0_5395 = torch.constant.int 0 - %int9223372036854775807_5396 = torch.constant.int 9223372036854775807 - %int1_5397 = torch.constant.int 1 - %4392 = torch.aten.slice.Tensor %4391, %int0_5394, %int0_5395, %int9223372036854775807_5396, %int1_5397 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4392, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_5398 = torch.constant.int 1 - %int0_5399 = torch.constant.int 0 - %int9223372036854775807_5400 = torch.constant.int 9223372036854775807 - %int1_5401 = torch.constant.int 1 - %4393 = torch.aten.slice.Tensor %4392, %int1_5398, %int0_5399, %int9223372036854775807_5400, %int1_5401 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4393, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_5402 = torch.constant.int 2 - %int0_5403 = torch.constant.int 0 - %4394 = torch.aten.select.int %4393, %int2_5402, %int0_5403 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4394, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_5404 = torch.constant.int 32 - %4395 = torch.aten.mul.int %358, %int32_5404 : !torch.int, !torch.int -> !torch.int - %int2_5405 = torch.constant.int 2 - %int0_5406 = torch.constant.int 0 - %int1_5407 = torch.constant.int 1 - %4396 = torch.aten.slice.Tensor %4394, %int2_5405, %int0_5406, %4395, %int1_5407 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4396, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_5408 = torch.constant.int 0 - %4397 = torch.aten.clone %4396, %int0_5408 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4397, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_5409 = torch.constant.int 1 - %4398 = torch.aten.size.int %4393, %int1_5409 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_5410 = torch.constant.int 32 - %4399 = torch.aten.mul.int %4398, %int32_5410 : !torch.int, !torch.int -> !torch.int - %int4_5411 = torch.constant.int 4 - %int8_5412 = torch.constant.int 8 - %int128_5413 = torch.constant.int 128 - %4400 = torch.prim.ListConstruct %int4_5411, %4399, %int8_5412, %int128_5413 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4401 = torch.aten._unsafe_view %4397, %4400 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4401, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_5414 = torch.constant.int 0 - %int0_5415 = torch.constant.int 0 - %int9223372036854775807_5416 = torch.constant.int 9223372036854775807 - %int1_5417 = torch.constant.int 1 - %4402 = torch.aten.slice.Tensor %4401, %int0_5414, %int0_5415, %int9223372036854775807_5416, %int1_5417 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4402, 
[%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_5418 = torch.constant.int 0 - %int0_5419 = torch.constant.int 0 - %int9223372036854775807_5420 = torch.constant.int 9223372036854775807 - %int1_5421 = torch.constant.int 1 - %4403 = torch.aten.slice.Tensor %4391, %int0_5418, %int0_5419, %int9223372036854775807_5420, %int1_5421 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4403, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_5422 = torch.constant.int 1 - %int0_5423 = torch.constant.int 0 - %int9223372036854775807_5424 = torch.constant.int 9223372036854775807 - %int1_5425 = torch.constant.int 1 - %4404 = torch.aten.slice.Tensor %4403, %int1_5422, %int0_5423, %int9223372036854775807_5424, %int1_5425 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4404, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_5426 = torch.constant.int 2 - %int1_5427 = torch.constant.int 1 - %4405 = torch.aten.select.int %4404, %int2_5426, %int1_5427 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4405, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_5428 = torch.constant.int 2 - %int0_5429 = torch.constant.int 0 - %int1_5430 = torch.constant.int 1 - %4406 = torch.aten.slice.Tensor %4405, %int2_5428, %int0_5429, %4395, %int1_5430 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4406, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int2_5386 = torch.constant.int 2 + %int8_5387 = torch.constant.int 8 + %int32_5388 = torch.constant.int 32 + %int128_5389 = torch.constant.int 128 + %5125 = torch.prim.ListConstruct %456, %int32_5385, %int2_5386, %int8_5387, %int32_5388, %int128_5389 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5126 = torch.aten.view %5124, %5125 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5126, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_5390 = torch.constant.int 2097152 + %5127 = torch.prim.ListConstruct %456, %int2097152_5390 : (!torch.int, !torch.int) -> !torch.list + %5128 = torch.aten.view %5126, %5127 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5128, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_5391 = torch.constant.none + %5129 = torch.aten.clone %302, %none_5391 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5130 = torch.aten.detach %5129 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5131 = torch.aten.detach %5130 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5132 = torch.aten.detach %5131 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_5392 = torch.constant.none + %5133 = torch.aten.clone %303, %none_5392 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5134 = torch.aten.detach %5133 : !torch.vtensor<[],si64> 
-> !torch.vtensor<[],si64> + %5135 = torch.aten.detach %5134 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5136 = torch.aten.detach %5135 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_5393 = torch.constant.none + %5137 = torch.aten.clone %304, %none_5393 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5138 = torch.aten.detach %5137 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5139 = torch.aten.detach %5138 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5140 = torch.aten.detach %5139 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_5394 = torch.constant.int 32 + %int2_5395 = torch.constant.int 2 + %int8_5396 = torch.constant.int 8 + %int32_5397 = torch.constant.int 32 + %int128_5398 = torch.constant.int 128 + %5141 = torch.prim.ListConstruct %456, %int32_5394, %int2_5395, %int8_5396, %int32_5397, %int128_5398 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %5142 = torch.aten.view %5128, %5141 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5142, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %5143 = torch_c.to_builtin_tensor %5142 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16> + %5144 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_5399 = tensor.cast %5144 : tensor<4x?xi64> to tensor<?x?xi64> + %5145 = torch_c.to_builtin_tensor %5132 : !torch.vtensor<[],si64> -> tensor<i64> + %5146 = torch_c.to_builtin_tensor %5136 : !torch.vtensor<[],si64> -> tensor<i64> + %5147 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%5143, %cast_5399, %5145, %5146) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16> + %cast_5400 = tensor.cast %5147 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16> + %5148 = torch_c.from_builtin_tensor %cast_5400 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %5148, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %5149 = torch_c.to_builtin_tensor %5142 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16> + %5150 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_5401 = tensor.cast %5150 : tensor<4x?xi64> to tensor<?x?xi64> + %5151 = torch_c.to_builtin_tensor %5132 : !torch.vtensor<[],si64> -> tensor<i64> + %5152 = torch_c.to_builtin_tensor %5140 : !torch.vtensor<[],si64> -> tensor<i64> + %5153 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%5149, %cast_5401, %5151, %5152) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16> + %cast_5402 = tensor.cast %5153 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16> + %5154 = torch_c.from_builtin_tensor %cast_5402 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %5154, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_5403 = torch.constant.int 2 + %int3_5404 = torch.constant.int 3 + %5155 = torch.aten.transpose.int %5148, %int2_5403, %int3_5404 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5155, 
[%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_5405 = torch.constant.int 0 + %5156 = torch.aten.clone %5155, %int0_5405 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5156, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_5406 = torch.constant.int 4 + %int8_5407 = torch.constant.int 8 + %int128_5408 = torch.constant.int 128 + %5157 = torch.prim.ListConstruct %int4_5406, %457, %int8_5407, %int128_5408 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5158 = torch.aten._unsafe_view %5156, %5157 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5158, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_5409 = torch.constant.int 2 + %int3_5410 = torch.constant.int 3 + %5159 = torch.aten.transpose.int %5154, %int2_5409, %int3_5410 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5159, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_5411 = torch.constant.int 0 + %5160 = torch.aten.clone %5159, %int0_5411 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5160, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_5412 = torch.constant.int 4 + %int8_5413 = torch.constant.int 8 + %int128_5414 = torch.constant.int 128 + %5161 = torch.prim.ListConstruct %int4_5412, %457, %int8_5413, %int128_5414 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5162 = torch.aten._unsafe_view %5160, %5161 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5162, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_5415 = torch.constant.int -2 + %5163 = torch.aten.unsqueeze %5158, %int-2_5415 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5163, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_5416 = torch.constant.int 4 + %int8_5417 = torch.constant.int 8 + %int4_5418 = torch.constant.int 4 + %int128_5419 = torch.constant.int 128 + %5164 = torch.prim.ListConstruct %int4_5416, %457, %int8_5417, %int4_5418, %int128_5419 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_5420 = torch.constant.bool false + %5165 = torch.aten.expand %5163, %5164, %false_5420 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5165, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_5421 = torch.constant.int 0 + %5166 = torch.aten.clone %5165, %int0_5421 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5166, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_5422 = torch.constant.int 4 + %int32_5423 = torch.constant.int 32 + %int128_5424 = torch.constant.int 128 + %5167 = torch.prim.ListConstruct %int4_5422, %457, %int32_5423, %int128_5424 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5168 = 
torch.aten._unsafe_view %5166, %5167 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5168, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_5425 = torch.constant.int -2 + %5169 = torch.aten.unsqueeze %5162, %int-2_5425 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5169, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_5426 = torch.constant.int 4 + %int8_5427 = torch.constant.int 8 + %int4_5428 = torch.constant.int 4 + %int128_5429 = torch.constant.int 128 + %5170 = torch.prim.ListConstruct %int4_5426, %457, %int8_5427, %int4_5428, %int128_5429 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_5430 = torch.constant.bool false + %5171 = torch.aten.expand %5169, %5170, %false_5430 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5171, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int0_5431 = torch.constant.int 0 - %4407 = torch.aten.clone %4406, %int0_5431 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4407, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_5432 = torch.constant.int 1 - %4408 = torch.aten.size.int %4404, %int1_5432 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int + %5172 = torch.aten.clone %5171, %int0_5431 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5172, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_5432 = torch.constant.int 4 %int32_5433 = torch.constant.int 32 - %4409 = torch.aten.mul.int %4408, %int32_5433 : !torch.int, !torch.int -> !torch.int - %int4_5434 = torch.constant.int 4 - %int8_5435 = torch.constant.int 8 - %int128_5436 = torch.constant.int 128 - %4410 = torch.prim.ListConstruct %int4_5434, %4409, %int8_5435, %int128_5436 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4411 = torch.aten._unsafe_view %4407, %4410 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4411, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_5437 = torch.constant.int 0 - %int0_5438 = torch.constant.int 0 - %int9223372036854775807_5439 = torch.constant.int 9223372036854775807 - %int1_5440 = torch.constant.int 1 - %4412 = torch.aten.slice.Tensor %4411, %int0_5437, %int0_5438, %int9223372036854775807_5439, %int1_5440 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4412, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_5441 = torch.constant.int -2 - %4413 = torch.aten.unsqueeze %4402, %int-2_5441 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4413, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_5442 = torch.constant.int 1 - %4414 = torch.aten.size.int %4401, %int1_5442 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_5443 = torch.constant.int 4 - %int8_5444 = 
torch.constant.int 8 - %int4_5445 = torch.constant.int 4 - %int128_5446 = torch.constant.int 128 - %4415 = torch.prim.ListConstruct %int4_5443, %4414, %int8_5444, %int4_5445, %int128_5446 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_5447 = torch.constant.bool false - %4416 = torch.aten.expand %4413, %4415, %false_5447 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4416, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_5448 = torch.constant.int 0 - %4417 = torch.aten.clone %4416, %int0_5448 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4417, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_5449 = torch.constant.int 4 - %int32_5450 = torch.constant.int 32 - %int128_5451 = torch.constant.int 128 - %4418 = torch.prim.ListConstruct %int4_5449, %4414, %int32_5450, %int128_5451 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4419 = torch.aten._unsafe_view %4417, %4418 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4419, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_5452 = torch.constant.int -2 - %4420 = torch.aten.unsqueeze %4412, %int-2_5452 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4420, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_5453 = torch.constant.int 1 - %4421 = torch.aten.size.int %4411, %int1_5453 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int + %int128_5434 = torch.constant.int 128 + %5173 = torch.prim.ListConstruct %int4_5432, %457, %int32_5433, %int128_5434 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5174 = torch.aten._unsafe_view %5172, %5173 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5174, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_5435 = torch.constant.int 1 + %int2_5436 = torch.constant.int 2 + %5175 = torch.aten.transpose.int %5057, %int1_5435, %int2_5436 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_5437 = torch.constant.int 1 + %int2_5438 = torch.constant.int 2 + %5176 = torch.aten.transpose.int %5168, %int1_5437, %int2_5438 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5176, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_5439 = torch.constant.int 1 + %int2_5440 = torch.constant.int 2 + %5177 = torch.aten.transpose.int %5174, %int1_5439, %int2_5440 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5177, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_5441 = torch.constant.float 0.000000e+00 + %false_5442 = torch.constant.bool false + %none_5443 = torch.constant.none + %5178:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5175, %5176, %5177, %float0.000000e00_5441, %false_5442, %470, %none_5443) : (!torch.vtensor<[4,32,1,128],f16>, 
!torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_5444 = torch.constant.int 1 + %int2_5445 = torch.constant.int 2 + %5179 = torch.aten.transpose.int %5178#0, %int1_5444, %int2_5445 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_5446 = torch.constant.int 4 + %int1_5447 = torch.constant.int 1 + %int4096_5448 = torch.constant.int 4096 + %5180 = torch.prim.ListConstruct %int4_5446, %int1_5447, %int4096_5448 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5181 = torch.aten.view %5179, %5180 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_5449 = torch.constant.int -2 + %int-1_5450 = torch.constant.int -1 + %5182 = torch.aten.transpose.int %305, %int-2_5449, %int-1_5450 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5451 = torch.constant.int 5 + %5183 = torch.prims.convert_element_type %5182, %int5_5451 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_5452 = torch.constant.int 4 + %int4096_5453 = torch.constant.int 4096 + %5184 = torch.prim.ListConstruct %int4_5452, %int4096_5453 : (!torch.int, !torch.int) -> !torch.list + %5185 = torch.aten.view %5181, %5184 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5186 = torch.aten.mm %5185, %5183 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_5454 = torch.constant.int 4 - %int8_5455 = torch.constant.int 8 - %int4_5456 = torch.constant.int 4 - %int128_5457 = torch.constant.int 128 - %4422 = torch.prim.ListConstruct %int4_5454, %4421, %int8_5455, %int4_5456, %int128_5457 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_5458 = torch.constant.bool false - %4423 = torch.aten.expand %4420, %4422, %false_5458 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4423, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_5459 = torch.constant.int 0 - %4424 = torch.aten.clone %4423, %int0_5459 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4424, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_5460 = torch.constant.int 4 - %int32_5461 = torch.constant.int 32 - %int128_5462 = torch.constant.int 128 - %4425 = torch.prim.ListConstruct %int4_5460, %4421, %int32_5461, %int128_5462 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4426 = torch.aten._unsafe_view %4424, %4425 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4426, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_5463 = torch.constant.int 1 - %int2_5464 = torch.constant.int 2 - %4427 = torch.aten.transpose.int %4307, %int1_5463, %int2_5464 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_5465 = torch.constant.int 1 - %int2_5466 = torch.constant.int 2 - %4428 = torch.aten.transpose.int %4419, %int1_5465, %int2_5466 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - 
torch.bind_symbolic_shape %4428, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_5467 = torch.constant.int 1 - %int2_5468 = torch.constant.int 2 - %4429 = torch.aten.transpose.int %4426, %int1_5467, %int2_5468 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4429, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_5469 = torch.constant.float 0.000000e+00 - %false_5470 = torch.constant.bool false - %none_5471 = torch.constant.none - %4430:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4427, %4428, %4429, %float0.000000e00_5469, %false_5470, %368, %none_5471) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_5472 = torch.constant.int 1 - %int2_5473 = torch.constant.int 2 - %4431 = torch.aten.transpose.int %4430#0, %int1_5472, %int2_5473 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_5474 = torch.constant.int 4 - %int1_5475 = torch.constant.int 1 - %int4096_5476 = torch.constant.int 4096 - %4432 = torch.prim.ListConstruct %int4_5474, %int1_5475, %int4096_5476 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4433 = torch.aten.view %4431, %4432 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_5477 = torch.constant.int -2 - %int-1_5478 = torch.constant.int -1 - %4434 = torch.aten.transpose.int %216, %int-2_5477, %int-1_5478 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_5479 = torch.constant.int 4 - %int4096_5480 = torch.constant.int 4096 - %4435 = torch.prim.ListConstruct %int4_5479, %int4096_5480 : (!torch.int, !torch.int) -> !torch.list - %4436 = torch.aten.view %4433, %4435 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4437 = torch.aten.mm %4436, %4434 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_5481 = torch.constant.int 4 - %int1_5482 = torch.constant.int 1 - %int4096_5483 = torch.constant.int 4096 - %4438 = torch.prim.ListConstruct %int4_5481, %int1_5482, %int4096_5483 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4439 = torch.aten.view %4437, %4438 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_5484 = torch.constant.int 1 - %4440 = torch.aten.add.Tensor %4267, %4439, %int1_5484 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_5485 = torch.constant.int 6 - %4441 = torch.prims.convert_element_type %4440, %int6_5485 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_5486 = torch.constant.int 2 - %4442 = torch.aten.pow.Tensor_Scalar %4441, %int2_5486 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_5487 = torch.constant.int -1 - %4443 = torch.prim.ListConstruct %int-1_5487 : (!torch.int) -> !torch.list - %true_5488 = torch.constant.bool true - %none_5489 = torch.constant.none - %4444 = torch.aten.mean.dim %4442, %4443, %true_5488, %none_5489 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - 
%float9.999990e-06_5490 = torch.constant.float 9.9999997473787516E-6 + %int1_5455 = torch.constant.int 1 + %int4096_5456 = torch.constant.int 4096 + %5187 = torch.prim.ListConstruct %int4_5454, %int1_5455, %int4096_5456 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5188 = torch.aten.view %5186, %5187 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_5457 = torch.constant.int 1 + %5189 = torch.aten.add.Tensor %5010, %5188, %int1_5457 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_5458 = torch.constant.int 6 + %5190 = torch.prims.convert_element_type %5189, %int6_5458 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_5459 = torch.constant.int 2 + %5191 = torch.aten.pow.Tensor_Scalar %5190, %int2_5459 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_5460 = torch.constant.int -1 + %5192 = torch.prim.ListConstruct %int-1_5460 : (!torch.int) -> !torch.list + %true_5461 = torch.constant.bool true + %none_5462 = torch.constant.none + %5193 = torch.aten.mean.dim %5191, %5192, %true_5461, %none_5462 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_5463 = torch.constant.float 9.9999997473787516E-6 + %int1_5464 = torch.constant.int 1 + %5194 = torch.aten.add.Scalar %5193, %float9.999990e-06_5463, %int1_5464 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %5195 = torch.aten.rsqrt %5194 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %5196 = torch.aten.mul.Tensor %5190, %5195 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_5465 = torch.constant.int 5 + %5197 = torch.prims.convert_element_type %5196, %int5_5465 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %5198 = torch.aten.mul.Tensor %306, %5197 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_5466 = torch.constant.int 5 + %5199 = torch.prims.convert_element_type %5198, %int5_5466 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_5467 = torch.constant.int -2 + %int-1_5468 = torch.constant.int -1 + %5200 = torch.aten.transpose.int %307, %int-2_5467, %int-1_5468 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5469 = torch.constant.int 5 + %5201 = torch.prims.convert_element_type %5200, %int5_5469 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_5470 = torch.constant.int 4 + %int4096_5471 = torch.constant.int 4096 + %5202 = torch.prim.ListConstruct %int4_5470, %int4096_5471 : (!torch.int, !torch.int) -> !torch.list + %5203 = torch.aten.view %5199, %5202 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5204 = torch.aten.mm %5203, %5201 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_5472 = torch.constant.int 4 + %int1_5473 = torch.constant.int 1 + %int14336_5474 = torch.constant.int 14336 + %5205 = torch.prim.ListConstruct %int4_5472, %int1_5473, %int14336_5474 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5206 = torch.aten.view %5204, %5205 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %5207 = torch.aten.silu %5206 : 
!torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_5475 = torch.constant.int -2 + %int-1_5476 = torch.constant.int -1 + %5208 = torch.aten.transpose.int %308, %int-2_5475, %int-1_5476 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5477 = torch.constant.int 5 + %5209 = torch.prims.convert_element_type %5208, %int5_5477 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_5478 = torch.constant.int 4 + %int4096_5479 = torch.constant.int 4096 + %5210 = torch.prim.ListConstruct %int4_5478, %int4096_5479 : (!torch.int, !torch.int) -> !torch.list + %5211 = torch.aten.view %5199, %5210 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5212 = torch.aten.mm %5211, %5209 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_5480 = torch.constant.int 4 + %int1_5481 = torch.constant.int 1 + %int14336_5482 = torch.constant.int 14336 + %5213 = torch.prim.ListConstruct %int4_5480, %int1_5481, %int14336_5482 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5214 = torch.aten.view %5212, %5213 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %5215 = torch.aten.mul.Tensor %5207, %5214 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_5483 = torch.constant.int -2 + %int-1_5484 = torch.constant.int -1 + %5216 = torch.aten.transpose.int %309, %int-2_5483, %int-1_5484 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_5485 = torch.constant.int 5 + %5217 = torch.prims.convert_element_type %5216, %int5_5485 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_5486 = torch.constant.int 4 + %int14336_5487 = torch.constant.int 14336 + %5218 = torch.prim.ListConstruct %int4_5486, %int14336_5487 : (!torch.int, !torch.int) -> !torch.list + %5219 = torch.aten.view %5215, %5218 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %5220 = torch.aten.mm %5219, %5217 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_5488 = torch.constant.int 4 + %int1_5489 = torch.constant.int 1 + %int4096_5490 = torch.constant.int 4096 + %5221 = torch.prim.ListConstruct %int4_5488, %int1_5489, %int4096_5490 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5222 = torch.aten.view %5220, %5221 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_5491 = torch.constant.int 1 - %4445 = torch.aten.add.Scalar %4444, %float9.999990e-06_5490, %int1_5491 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %4446 = torch.aten.rsqrt %4445 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %4447 = torch.aten.mul.Tensor %4441, %4446 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_5492 = torch.constant.int 5 - %4448 = torch.prims.convert_element_type %4447, %int5_5492 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %4449 = torch.aten.mul.Tensor %217, %4448 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_5493 = torch.constant.int 5 - %4450 = torch.prims.convert_element_type %4449, %int5_5493 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> 
!torch.vtensor<[4,1,4096],f16> - %int-2_5494 = torch.constant.int -2 - %int-1_5495 = torch.constant.int -1 - %4451 = torch.aten.transpose.int %218, %int-2_5494, %int-1_5495 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_5496 = torch.constant.int 4 - %int4096_5497 = torch.constant.int 4096 - %4452 = torch.prim.ListConstruct %int4_5496, %int4096_5497 : (!torch.int, !torch.int) -> !torch.list - %4453 = torch.aten.view %4450, %4452 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4454 = torch.aten.mm %4453, %4451 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_5498 = torch.constant.int 4 - %int1_5499 = torch.constant.int 1 - %int14336_5500 = torch.constant.int 14336 - %4455 = torch.prim.ListConstruct %int4_5498, %int1_5499, %int14336_5500 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4456 = torch.aten.view %4454, %4455 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %4457 = torch.aten.silu %4456 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %5223 = torch.aten.add.Tensor %5189, %5222, %int1_5491 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_5492 = torch.constant.int 6 + %5224 = torch.prims.convert_element_type %5223, %int6_5492 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_5493 = torch.constant.int 2 + %5225 = torch.aten.pow.Tensor_Scalar %5224, %int2_5493 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_5494 = torch.constant.int -1 + %5226 = torch.prim.ListConstruct %int-1_5494 : (!torch.int) -> !torch.list + %true_5495 = torch.constant.bool true + %none_5496 = torch.constant.none + %5227 = torch.aten.mean.dim %5225, %5226, %true_5495, %none_5496 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_5497 = torch.constant.float 9.9999997473787516E-6 + %int1_5498 = torch.constant.int 1 + %5228 = torch.aten.add.Scalar %5227, %float9.999990e-06_5497, %int1_5498 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %5229 = torch.aten.rsqrt %5228 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %5230 = torch.aten.mul.Tensor %5224, %5229 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_5499 = torch.constant.int 5 + %5231 = torch.prims.convert_element_type %5230, %int5_5499 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %5232 = torch.aten.mul.Tensor %310, %5231 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_5500 = torch.constant.int 5 + %5233 = torch.prims.convert_element_type %5232, %int5_5500 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> %int-2_5501 = torch.constant.int -2 %int-1_5502 = torch.constant.int -1 - %4458 = torch.aten.transpose.int %219, %int-2_5501, %int-1_5502 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_5503 = torch.constant.int 4 - %int4096_5504 = torch.constant.int 4096 - %4459 = torch.prim.ListConstruct %int4_5503, %int4096_5504 : (!torch.int, !torch.int) -> !torch.list - %4460 = torch.aten.view %4450, %4459 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> 
!torch.vtensor<[4,4096],f16> - %4461 = torch.aten.mm %4460, %4458 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_5505 = torch.constant.int 4 - %int1_5506 = torch.constant.int 1 - %int14336_5507 = torch.constant.int 14336 - %4462 = torch.prim.ListConstruct %int4_5505, %int1_5506, %int14336_5507 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4463 = torch.aten.view %4461, %4462 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %4464 = torch.aten.mul.Tensor %4457, %4463 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_5508 = torch.constant.int -2 - %int-1_5509 = torch.constant.int -1 - %4465 = torch.aten.transpose.int %220, %int-2_5508, %int-1_5509 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_5510 = torch.constant.int 4 - %int14336_5511 = torch.constant.int 14336 - %4466 = torch.prim.ListConstruct %int4_5510, %int14336_5511 : (!torch.int, !torch.int) -> !torch.list - %4467 = torch.aten.view %4464, %4466 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %4468 = torch.aten.mm %4467, %4465 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %5234 = torch.aten.transpose.int %311, %int-2_5501, %int-1_5502 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5503 = torch.constant.int 5 + %5235 = torch.prims.convert_element_type %5234, %int5_5503 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_5504 = torch.constant.int 4 + %int4096_5505 = torch.constant.int 4096 + %5236 = torch.prim.ListConstruct %int4_5504, %int4096_5505 : (!torch.int, !torch.int) -> !torch.list + %5237 = torch.aten.view %5233, %5236 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5238 = torch.aten.mm %5237, %5235 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_5506 = torch.constant.int 4 + %int1_5507 = torch.constant.int 1 + %int4096_5508 = torch.constant.int 4096 + %5239 = torch.prim.ListConstruct %int4_5506, %int1_5507, %int4096_5508 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5240 = torch.aten.view %5238, %5239 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_5509 = torch.constant.int -2 + %int-1_5510 = torch.constant.int -1 + %5241 = torch.aten.transpose.int %312, %int-2_5509, %int-1_5510 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5511 = torch.constant.int 5 + %5242 = torch.prims.convert_element_type %5241, %int5_5511 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> %int4_5512 = torch.constant.int 4 - %int1_5513 = torch.constant.int 1 - %int4096_5514 = torch.constant.int 4096 - %4469 = torch.prim.ListConstruct %int4_5512, %int1_5513, %int4096_5514 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4470 = torch.aten.view %4468, %4469 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int4096_5513 = torch.constant.int 4096 + %5243 = torch.prim.ListConstruct %int4_5512, %int4096_5513 : (!torch.int, !torch.int) -> !torch.list + %5244 = torch.aten.view %5233, %5243 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5245 = torch.aten.mm %5244, %5242 : 
!torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_5514 = torch.constant.int 4 %int1_5515 = torch.constant.int 1 - %4471 = torch.aten.add.Tensor %4440, %4470, %int1_5515 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_5516 = torch.constant.int 6 - %4472 = torch.prims.convert_element_type %4471, %int6_5516 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_5517 = torch.constant.int 2 - %4473 = torch.aten.pow.Tensor_Scalar %4472, %int2_5517 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int1024_5516 = torch.constant.int 1024 + %5246 = torch.prim.ListConstruct %int4_5514, %int1_5515, %int1024_5516 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5247 = torch.aten.view %5245, %5246 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_5517 = torch.constant.int -2 %int-1_5518 = torch.constant.int -1 - %4474 = torch.prim.ListConstruct %int-1_5518 : (!torch.int) -> !torch.list - %true_5519 = torch.constant.bool true - %none_5520 = torch.constant.none - %4475 = torch.aten.mean.dim %4473, %4474, %true_5519, %none_5520 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_5521 = torch.constant.float 9.9999997473787516E-6 - %int1_5522 = torch.constant.int 1 - %4476 = torch.aten.add.Scalar %4475, %float9.999990e-06_5521, %int1_5522 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %4477 = torch.aten.rsqrt %4476 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %4478 = torch.aten.mul.Tensor %4472, %4477 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_5523 = torch.constant.int 5 - %4479 = torch.prims.convert_element_type %4478, %int5_5523 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %4480 = torch.aten.mul.Tensor %221, %4479 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_5524 = torch.constant.int 5 - %4481 = torch.prims.convert_element_type %4480, %int5_5524 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_5525 = torch.constant.int -2 - %int-1_5526 = torch.constant.int -1 - %4482 = torch.aten.transpose.int %222, %int-2_5525, %int-1_5526 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_5527 = torch.constant.int 4 - %int4096_5528 = torch.constant.int 4096 - %4483 = torch.prim.ListConstruct %int4_5527, %int4096_5528 : (!torch.int, !torch.int) -> !torch.list - %4484 = torch.aten.view %4481, %4483 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4485 = torch.aten.mm %4484, %4482 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %5248 = torch.aten.transpose.int %313, %int-2_5517, %int-1_5518 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5519 = torch.constant.int 5 + %5249 = torch.prims.convert_element_type %5248, %int5_5519 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_5520 = torch.constant.int 4 + %int4096_5521 = torch.constant.int 4096 + %5250 = torch.prim.ListConstruct %int4_5520, %int4096_5521 : (!torch.int, !torch.int) -> !torch.list + %5251 = 
torch.aten.view %5233, %5250 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5252 = torch.aten.mm %5251, %5249 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_5522 = torch.constant.int 4 + %int1_5523 = torch.constant.int 1 + %int1024_5524 = torch.constant.int 1024 + %5253 = torch.prim.ListConstruct %int4_5522, %int1_5523, %int1024_5524 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5254 = torch.aten.view %5252, %5253 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_5525 = torch.constant.int 4 + %int1_5526 = torch.constant.int 1 + %int32_5527 = torch.constant.int 32 + %int128_5528 = torch.constant.int 128 + %5255 = torch.prim.ListConstruct %int4_5525, %int1_5526, %int32_5527, %int128_5528 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5256 = torch.aten.view %5240, %5255 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> %int4_5529 = torch.constant.int 4 %int1_5530 = torch.constant.int 1 - %int4096_5531 = torch.constant.int 4096 - %4486 = torch.prim.ListConstruct %int4_5529, %int1_5530, %int4096_5531 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4487 = torch.aten.view %4485, %4486 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_5532 = torch.constant.int -2 - %int-1_5533 = torch.constant.int -1 - %4488 = torch.aten.transpose.int %223, %int-2_5532, %int-1_5533 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_5534 = torch.constant.int 4 - %int4096_5535 = torch.constant.int 4096 - %4489 = torch.prim.ListConstruct %int4_5534, %int4096_5535 : (!torch.int, !torch.int) -> !torch.list - %4490 = torch.aten.view %4481, %4489 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4491 = torch.aten.mm %4490, %4488 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_5536 = torch.constant.int 4 + %int8_5531 = torch.constant.int 8 + %int128_5532 = torch.constant.int 128 + %5257 = torch.prim.ListConstruct %int4_5529, %int1_5530, %int8_5531, %int128_5532 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5258 = torch.aten.view %5247, %5257 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_5533 = torch.constant.int 4 + %int1_5534 = torch.constant.int 1 + %int8_5535 = torch.constant.int 8 + %int128_5536 = torch.constant.int 128 + %5259 = torch.prim.ListConstruct %int4_5533, %int1_5534, %int8_5535, %int128_5536 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5260 = torch.aten.view %5254, %5259 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> %int1_5537 = torch.constant.int 1 - %int1024_5538 = torch.constant.int 1024 - %4492 = torch.prim.ListConstruct %int4_5536, %int1_5537, %int1024_5538 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4493 = torch.aten.view %4491, %4492 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_5539 = torch.constant.int -2 - %int-1_5540 = torch.constant.int -1 - %4494 = torch.aten.transpose.int %224, %int-2_5539, %int-1_5540 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_5541 = torch.constant.int 4 - %int4096_5542 = torch.constant.int 4096 - %4495 = torch.prim.ListConstruct %int4_5541, %int4096_5542 : (!torch.int, 
!torch.int) -> !torch.list - %4496 = torch.aten.view %4481, %4495 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4497 = torch.aten.mm %4496, %4494 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_5543 = torch.constant.int 4 - %int1_5544 = torch.constant.int 1 - %int1024_5545 = torch.constant.int 1024 - %4498 = torch.prim.ListConstruct %int4_5543, %int1_5544, %int1024_5545 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4499 = torch.aten.view %4497, %4498 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_5546 = torch.constant.int 4 - %int1_5547 = torch.constant.int 1 - %int32_5548 = torch.constant.int 32 - %int128_5549 = torch.constant.int 128 - %4500 = torch.prim.ListConstruct %int4_5546, %int1_5547, %int32_5548, %int128_5549 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4501 = torch.aten.view %4487, %4500 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_5550 = torch.constant.int 4 + %int2_5538 = torch.constant.int 2 + %5261 = torch.aten.transpose.int %5256, %int1_5537, %int2_5538 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %5262 = torch.aten.mul.Tensor %5261, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_5539 = torch.constant.int 3 + %int0_5540 = torch.constant.int 0 + %int64_5541 = torch.constant.int 64 + %int1_5542 = torch.constant.int 1 + %5263 = torch.aten.slice.Tensor %5261, %int3_5539, %int0_5540, %int64_5541, %int1_5542 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_5543 = torch.constant.int 3 + %int64_5544 = torch.constant.int 64 + %int9223372036854775807_5545 = torch.constant.int 9223372036854775807 + %int1_5546 = torch.constant.int 1 + %5264 = torch.aten.slice.Tensor %5261, %int3_5543, %int64_5544, %int9223372036854775807_5545, %int1_5546 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %5265 = torch.aten.neg %5264 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %5266 = torch.prim.ListConstruct %5265, %5263 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_5547 = torch.constant.int -1 + %5267 = torch.aten.cat %5266, %int-1_5547 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %5268 = torch.aten.mul.Tensor %5267, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_5548 = torch.constant.int 1 + %5269 = torch.aten.add.Tensor %5262, %5268, %int1_5548 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_5549 = torch.constant.int 1 + %int2_5550 = torch.constant.int 2 + %5270 = torch.aten.transpose.int %5269, %int1_5549, %int2_5550 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> %int1_5551 = torch.constant.int 1 - %int8_5552 = torch.constant.int 8 - %int128_5553 = torch.constant.int 128 - %4502 = torch.prim.ListConstruct %int4_5550, %int1_5551, %int8_5552, %int128_5553 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4503 = torch.aten.view %4493, %4502 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - 
%int4_5554 = torch.constant.int 4 - %int1_5555 = torch.constant.int 1 - %int8_5556 = torch.constant.int 8 - %int128_5557 = torch.constant.int 128 - %4504 = torch.prim.ListConstruct %int4_5554, %int1_5555, %int8_5556, %int128_5557 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4505 = torch.aten.view %4499, %4504 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_5558 = torch.constant.int 6 - %4506 = torch.prims.convert_element_type %4501, %int6_5558 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %4507 = torch_c.to_builtin_tensor %4506 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %4508 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %4509 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%4507, %4508) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %4510 = torch_c.from_builtin_tensor %4509 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_5559 = torch.constant.int 5 - %4511 = torch.prims.convert_element_type %4510, %int5_5559 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_5560 = torch.constant.int 6 - %4512 = torch.prims.convert_element_type %4503, %int6_5560 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %4513 = torch_c.to_builtin_tensor %4512 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %4514 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %4515 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%4513, %4514) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %4516 = torch_c.from_builtin_tensor %4515 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_5561 = torch.constant.int 5 - %4517 = torch.prims.convert_element_type %4516, %int5_5561 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_5562 = torch.constant.int 32 - %4518 = torch.aten.floor_divide.Scalar %arg2, %int32_5562 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int2_5552 = torch.constant.int 2 + %5271 = torch.aten.transpose.int %5258, %int1_5551, %int2_5552 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %5272 = torch.aten.mul.Tensor %5271, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_5553 = torch.constant.int 3 + %int0_5554 = torch.constant.int 0 + %int64_5555 = torch.constant.int 64 + %int1_5556 = torch.constant.int 1 + %5273 = torch.aten.slice.Tensor %5271, %int3_5553, %int0_5554, %int64_5555, %int1_5556 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_5557 = torch.constant.int 3 + %int64_5558 = torch.constant.int 64 + %int9223372036854775807_5559 = torch.constant.int 9223372036854775807 + %int1_5560 = torch.constant.int 1 + %5274 = torch.aten.slice.Tensor %5271, %int3_5557, %int64_5558, %int9223372036854775807_5559, %int1_5560 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %5275 = torch.aten.neg %5274 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %5276 = torch.prim.ListConstruct %5275, %5273 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_5561 = 
torch.constant.int -1 + %5277 = torch.aten.cat %5276, %int-1_5561 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %5278 = torch.aten.mul.Tensor %5277, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_5562 = torch.constant.int 1 + %5279 = torch.aten.add.Tensor %5272, %5278, %int1_5562 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> %int1_5563 = torch.constant.int 1 - %4519 = torch.aten.unsqueeze %4518, %int1_5563 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5564 = torch.constant.int 1 - %false_5565 = torch.constant.bool false - %4520 = torch.aten.gather %arg3, %int1_5564, %4519, %false_5565 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_5566 = torch.constant.int 32 - %4521 = torch.aten.remainder.Scalar %arg2, %int32_5566 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int2_5564 = torch.constant.int 2 + %5280 = torch.aten.transpose.int %5279, %int1_5563, %int2_5564 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_5565 = torch.constant.int 32 + %5281 = torch.aten.floor_divide.Scalar %arg2, %int32_5565 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_5566 = torch.constant.int 1 + %5282 = torch.aten.unsqueeze %5281, %int1_5566 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> %int1_5567 = torch.constant.int 1 - %4522 = torch.aten.unsqueeze %4521, %int1_5567 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_5568 = torch.constant.none - %4523 = torch.aten.clone %225, %none_5568 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_5569 = torch.constant.int 0 - %4524 = torch.aten.unsqueeze %4523, %int0_5569 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_5570 = torch.constant.int 4 + %false_5568 = torch.constant.bool false + %5283 = torch.aten.gather %arg3, %int1_5567, %5282, %false_5568 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_5569 = torch.constant.int 4 + %int1_5570 = torch.constant.int 1 %int1_5571 = torch.constant.int 1 - %4525 = torch.prim.ListConstruct %int4_5570, %int1_5571 : (!torch.int, !torch.int) -> !torch.list - %int1_5572 = torch.constant.int 1 - %int1_5573 = torch.constant.int 1 - %4526 = torch.prim.ListConstruct %int1_5572, %int1_5573 : (!torch.int, !torch.int) -> !torch.list - %int4_5574 = torch.constant.int 4 - %int0_5575 = torch.constant.int 0 - %cpu_5576 = torch.constant.device "cpu" - %false_5577 = torch.constant.bool false - %4527 = torch.aten.empty_strided %4525, %4526, %int4_5574, %int0_5575, %cpu_5576, %false_5577 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int20 = torch.constant.int 20 - %4528 = torch.aten.fill.Scalar %4527, %int20 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_5578 = torch.constant.int 4 - %int1_5579 = torch.constant.int 1 - %4529 = torch.prim.ListConstruct %int4_5578, %int1_5579 : (!torch.int, !torch.int) -> !torch.list - %4530 = torch.aten.repeat %4524, %4529 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_5580 = torch.constant.int 32 - %4531 = torch.aten.mul.Scalar %4520, %int32_5580 : 
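The removed IR round-tripped Q and K through f32 and the `sharktank_rotary_embedding_*` custom calls; the replacement inlines the standard rotate-half RoPE directly in f16 (`%5261`-`%5270` for Q, `%5271`-`%5280` for K), with `%527` and `%531` apparently the broadcast cos and sin tables for the single decode position. A minimal PyTorch sketch of what the added ops compute (shapes are taken from the IR, variable names are illustrative):

import torch

def rotate_half(x):
    # The slice/neg/cat sequence, cf. %5263, %5264, %5265, %5267.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin):
    # x: [batch, heads, 1, 128]; cos/sin: [batch, 1, 1, 128], cf. %5262/%5268/%5269.
    return x * cos + rotate_half(x) * sin

q = torch.randn(4, 32, 1, 128, dtype=torch.float16)   # %5261
k = torch.randn(4, 8, 1, 128, dtype=torch.float16)    # %5271
cos = torch.randn(4, 1, 1, 128, dtype=torch.float16)  # %527 (assumed)
sin = torch.randn(4, 1, 1, 128, dtype=torch.float16)  # %531 (assumed)
q_rot, k_rot = apply_rope(q, cos, sin), apply_rope(k, cos, sin)

Keeping the rotation in f16 also drops the convert_element_type pairs the old lowering paid around each custom call.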
!torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %5284 = torch.prim.ListConstruct %int4_5569, %int1_5570, %int1_5571 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5285 = torch.aten.view %5283, %5284 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_5572 = torch.constant.int 32 + %5286 = torch.aten.remainder.Scalar %arg2, %int32_5572 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_5573 = torch.constant.int 4 + %int1_5574 = torch.constant.int 1 + %int1_5575 = torch.constant.int 1 + %5287 = torch.prim.ListConstruct %int4_5573, %int1_5574, %int1_5575 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5288 = torch.aten.view %5286, %5287 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_5576 = torch.constant.int 8 + %none_5577 = torch.constant.none + %none_5578 = torch.constant.none + %cpu_5579 = torch.constant.device "cpu" + %false_5580 = torch.constant.bool false + %5289 = torch.aten.arange %int8_5576, %none_5577, %none_5578, %cpu_5579, %false_5580 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> %int1_5581 = torch.constant.int 1 - %4532 = torch.aten.add.Tensor %4531, %4528, %int1_5581 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_5582 = torch.constant.int 2 - %4533 = torch.aten.mul.Scalar %4532, %int2_5582 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5583 = torch.constant.int 1 - %4534 = torch.aten.add.Tensor %4533, %4530, %int1_5583 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_5584 = torch.constant.int 32 - %4535 = torch.aten.mul.Scalar %4534, %int32_5584 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_5582 = torch.constant.int 1 + %int8_5583 = torch.constant.int 8 + %5290 = torch.prim.ListConstruct %int1_5581, %int1_5582, %int8_5583 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5291 = torch.aten.view %5289, %5290 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_5584 = torch.constant.none + %5292 = torch.aten.clone %314, %none_5584 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5293 = torch.aten.detach %5292 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5294 = torch.aten.detach %5293 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5295 = torch.aten.detach %5294 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %int1_5585 = torch.constant.int 1 - %4536 = torch.aten.add.Tensor %4535, %4522, %int1_5585 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_5586 = torch.constant.int 32 - %int2_5587 = torch.constant.int 2 + %int1_5586 = torch.constant.int 1 + %int1_5587 = torch.constant.int 1 + %5296 = torch.prim.ListConstruct %int1_5585, %int1_5586, %int1_5587 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5297 = torch.aten.view %5295, %5296 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> %int32_5588 = torch.constant.int 32 - %int8_5589 = torch.constant.int 8 - %int128_5590 = torch.constant.int 128 - %4537 = torch.prim.ListConstruct %437, %int32_5586, %int2_5587, %int32_5588, %int8_5589, %int128_5590 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4538 = torch.aten.view %4374, %4537 : !torch.vtensor<[?,2097152],f16>, 
!torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4538, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5591 = torch.constant.int 32 - %4539 = torch.aten.mul.int %437, %int32_5591 : !torch.int, !torch.int -> !torch.int - %int2_5592 = torch.constant.int 2 - %4540 = torch.aten.mul.int %4539, %int2_5592 : !torch.int, !torch.int -> !torch.int - %int32_5593 = torch.constant.int 32 - %4541 = torch.aten.mul.int %4540, %int32_5593 : !torch.int, !torch.int -> !torch.int - %int8_5594 = torch.constant.int 8 - %int128_5595 = torch.constant.int 128 - %4542 = torch.prim.ListConstruct %4541, %int8_5594, %int128_5595 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4543 = torch.aten.view %4538, %4542 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4543, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %4544 = torch.prim.ListConstruct %4536 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_5596 = torch.constant.bool false - %4545 = torch.aten.index_put %4543, %4544, %4517, %false_5596 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4545, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> + %5298 = torch.aten.mul.Scalar %5285, %int32_5588 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int22 = torch.constant.int 22 + %int1_5589 = torch.constant.int 1 + %5299 = torch.aten.add.Scalar %5298, %int22, %int1_5589 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_5590 = torch.constant.int 2 + %5300 = torch.aten.mul.Scalar %5299, %int2_5590 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_5591 = torch.constant.int 1 + %5301 = torch.aten.add.Tensor %5300, %5297, %int1_5591 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_5592 = torch.constant.int 8 + %5302 = torch.aten.mul.Scalar %5301, %int8_5592 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_5593 = torch.constant.int 1 + %5303 = torch.aten.add.Tensor %5302, %5291, %int1_5593 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_5594 = torch.constant.int 32 + %5304 = torch.aten.mul.Scalar %5303, %int32_5594 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_5595 = torch.constant.int 1 + %5305 = torch.aten.add.Tensor %5304, %5288, %int1_5595 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_5596 = torch.constant.int 5 + %5306 = torch.prims.convert_element_type %5280, %int5_5596 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> %int32_5597 = torch.constant.int 32 %int2_5598 = torch.constant.int 2 - %int32_5599 = torch.constant.int 32 - %int8_5600 = torch.constant.int 8 + %int8_5599 = torch.constant.int 8 + %int32_5600 = torch.constant.int 32 %int128_5601 = torch.constant.int 128 - %4546 = torch.prim.ListConstruct %437, %int32_5597, %int2_5598, %int32_5599, %int8_5600, %int128_5601 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4547 = torch.aten.view %4545, %4546 : !torch.vtensor<[?,8,128],f16>, 
!torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4547, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5602 = torch.constant.int 2097152 - %4548 = torch.prim.ListConstruct %437, %int2097152_5602 : (!torch.int, !torch.int) -> !torch.list - %4549 = torch.aten.view %4547, %4548 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4549, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_5603 = torch.constant.int 32 - %int2_5604 = torch.constant.int 2 - %int32_5605 = torch.constant.int 32 + %5307 = torch.prim.ListConstruct %456, %int32_5597, %int2_5598, %int8_5599, %int32_5600, %int128_5601 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5308 = torch.aten.view %5128, %5307 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5308, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_5602 = torch.constant.int 128 + %5309 = torch.prim.ListConstruct %596, %int128_5602 : (!torch.int, !torch.int) -> !torch.list + %5310 = torch.aten.view %5308, %5309 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5310, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %5311 = torch.prim.ListConstruct %5305 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_5603 = torch.constant.bool false + %5312 = torch.aten.index_put %5310, %5311, %5306, %false_5603 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5312, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_5604 = torch.constant.int 32 + %int2_5605 = torch.constant.int 2 %int8_5606 = torch.constant.int 8 - %int128_5607 = torch.constant.int 128 - %4550 = torch.prim.ListConstruct %437, %int32_5603, %int2_5604, %int32_5605, %int8_5606, %int128_5607 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4551 = torch.aten.view %4549, %4550 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4551, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_5608 = torch.constant.int 8 - %int128_5609 = torch.constant.int 128 - %4552 = torch.prim.ListConstruct %4541, %int8_5608, %int128_5609 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4553 = torch.aten.view %4551, %4552 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4553, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> + %int32_5607 = torch.constant.int 32 + %int128_5608 = torch.constant.int 128 + %5313 = torch.prim.ListConstruct %456, %int32_5604, %int2_5605, %int8_5606, %int32_5607, %int128_5608 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5314 = torch.aten.view %5312, %5313 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5314, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_5609 = torch.constant.int 2097152 + %5315 = 
torch.prim.ListConstruct %456, %int2097152_5609 : (!torch.int, !torch.int) -> !torch.list + %5316 = torch.aten.view %5314, %5315 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5316, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> %int32_5610 = torch.constant.int 32 - %4554 = torch.aten.floor_divide.Scalar %arg2, %int32_5610 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_5611 = torch.constant.int 1 - %4555 = torch.aten.unsqueeze %4554, %int1_5611 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5612 = torch.constant.int 1 - %false_5613 = torch.constant.bool false - %4556 = torch.aten.gather %arg3, %int1_5612, %4555, %false_5613 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_5614 = torch.constant.int 32 - %4557 = torch.aten.remainder.Scalar %arg2, %int32_5614 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_5615 = torch.constant.int 1 - %4558 = torch.aten.unsqueeze %4557, %int1_5615 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int2_5611 = torch.constant.int 2 + %int8_5612 = torch.constant.int 8 + %int32_5613 = torch.constant.int 32 + %int128_5614 = torch.constant.int 128 + %5317 = torch.prim.ListConstruct %456, %int32_5610, %int2_5611, %int8_5612, %int32_5613, %int128_5614 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5318 = torch.aten.view %5316, %5317 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5318, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_5615 = torch.constant.int 128 + %5319 = torch.prim.ListConstruct %596, %int128_5615 : (!torch.int, !torch.int) -> !torch.list + %5320 = torch.aten.view %5318, %5319 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5320, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> %none_5616 = torch.constant.none - %4559 = torch.aten.clone %226, %none_5616 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_5617 = torch.constant.int 0 - %4560 = torch.aten.unsqueeze %4559, %int0_5617 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_5618 = torch.constant.int 4 + %5321 = torch.aten.clone %315, %none_5616 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5322 = torch.aten.detach %5321 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5323 = torch.aten.detach %5322 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5324 = torch.aten.detach %5323 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_5617 = torch.constant.int 1 + %int1_5618 = torch.constant.int 1 %int1_5619 = torch.constant.int 1 - %4561 = torch.prim.ListConstruct %int4_5618, %int1_5619 : (!torch.int, !torch.int) -> !torch.list - %int1_5620 = torch.constant.int 1 - %int1_5621 = torch.constant.int 1 - %4562 = torch.prim.ListConstruct %int1_5620, %int1_5621 : (!torch.int, !torch.int) -> !torch.list - %int4_5622 = torch.constant.int 4 - %int0_5623 = torch.constant.int 0 - %cpu_5624 = torch.constant.device "cpu" - %false_5625 = torch.constant.bool false - %4563 = torch.aten.empty_strided %4561, %4562, %int4_5622, %int0_5623, %cpu_5624, %false_5625 : !torch.list, !torch.list, 
!torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int20_5626 = torch.constant.int 20 - %4564 = torch.aten.fill.Scalar %4563, %int20_5626 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_5627 = torch.constant.int 4 + %5325 = torch.prim.ListConstruct %int1_5617, %int1_5618, %int1_5619 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5326 = torch.aten.view %5324, %5325 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_5620 = torch.constant.int 32 + %5327 = torch.aten.mul.Scalar %5285, %int32_5620 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int22_5621 = torch.constant.int 22 + %int1_5622 = torch.constant.int 1 + %5328 = torch.aten.add.Scalar %5327, %int22_5621, %int1_5622 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_5623 = torch.constant.int 2 + %5329 = torch.aten.mul.Scalar %5328, %int2_5623 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_5624 = torch.constant.int 1 + %5330 = torch.aten.add.Tensor %5329, %5326, %int1_5624 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_5625 = torch.constant.int 8 + %5331 = torch.aten.mul.Scalar %5330, %int8_5625 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_5626 = torch.constant.int 1 + %5332 = torch.aten.add.Tensor %5331, %5291, %int1_5626 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_5627 = torch.constant.int 32 + %5333 = torch.aten.mul.Scalar %5332, %int32_5627 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_5628 = torch.constant.int 1 - %4565 = torch.prim.ListConstruct %int4_5627, %int1_5628 : (!torch.int, !torch.int) -> !torch.list - %4566 = torch.aten.repeat %4560, %4565 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_5629 = torch.constant.int 32 - %4567 = torch.aten.mul.Scalar %4556, %int32_5629 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5630 = torch.constant.int 1 - %4568 = torch.aten.add.Tensor %4567, %4564, %int1_5630 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_5631 = torch.constant.int 2 - %4569 = torch.aten.mul.Scalar %4568, %int2_5631 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5632 = torch.constant.int 1 - %4570 = torch.aten.add.Tensor %4569, %4566, %int1_5632 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_5633 = torch.constant.int 32 - %4571 = torch.aten.mul.Scalar %4570, %int32_5633 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5634 = torch.constant.int 1 - %4572 = torch.aten.add.Tensor %4571, %4558, %int1_5634 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %4573 = torch.prim.ListConstruct %4572 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_5635 = torch.constant.bool false - %4574 = torch.aten.index_put %4553, %4573, %4505, %false_5635 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4574, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_5636 = torch.constant.int 32 
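The cache-write layout changes here as well: the old IR viewed the flat [?,2097152] buffer as [pages,32,2,32,8,128] and scattered whole [8,128] slabs per token (bind shape `s0 * 2048, 8, 128`), while the new IR uses [pages,32,2,8,32,128] with kv-heads ahead of block position and scatters one 128-wide row per (token, head) (bind shape `s0 * 16384, 128`). The scalar chain `%5298`-`%5305` reads as a mixed-radix flattening over (page, layer, k/v partition, head, in-block position); the literal 22 appears to be this block's layer slot and `%314`/`%315` the K/V partition selectors, and the V write (`%5326`-`%5337`) repeats the same arithmetic for partition 1. A sketch under those assumptions:

import torch

BLOCK, LAYERS, PARTS, KV_HEADS, HEAD_DIM = 32, 32, 2, 8, 128

def write_kv(cache, page_table, pos, new_kv, layer, part):
    # cache: [pages, 2097152] flat paged KV buffer, seen as 128-wide rows.
    # page_table: [batch, max_pages]; pos: [batch] absolute token positions.
    # new_kv: [batch, 1, KV_HEADS, HEAD_DIM] (the rotated K, or V).
    rows = cache.view(-1, HEAD_DIM)                              # cf. %5310
    page = page_table.gather(1, (pos // BLOCK).unsqueeze(1))     # cf. %5283
    head = torch.arange(KV_HEADS).view(1, 1, KV_HEADS)           # cf. %5291
    idx = (page.view(-1, 1, 1) * LAYERS + layer) * PARTS + part  # cf. %5298..%5301
    idx = (idx * KV_HEADS + head) * BLOCK + (pos % BLOCK).view(-1, 1, 1)
    rows[idx] = new_kv                                           # cf. %5312 index_put
    return cache

cache = torch.zeros(4, 2097152, dtype=torch.float16)
page_table = torch.arange(4).repeat(4, 1)
pos = torch.tensor([0, 5, 33, 70])
k_new = torch.randn(4, 1, KV_HEADS, HEAD_DIM, dtype=torch.float16)
write_kv(cache, page_table, pos, k_new, layer=22, part=0)  # part=1 for V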
- %int2_5637 = torch.constant.int 2 - %int32_5638 = torch.constant.int 32 - %int8_5639 = torch.constant.int 8 - %int128_5640 = torch.constant.int 128 - %4575 = torch.prim.ListConstruct %437, %int32_5636, %int2_5637, %int32_5638, %int8_5639, %int128_5640 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4576 = torch.aten.view %4574, %4575 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4576, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5641 = torch.constant.int 2097152 - %4577 = torch.prim.ListConstruct %437, %int2097152_5641 : (!torch.int, !torch.int) -> !torch.list - %4578 = torch.aten.view %4576, %4577 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4578, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_5642 = torch.constant.int 4 - %4579 = torch.prim.ListConstruct %int4_5642, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_5643 = torch.constant.int 1 - %4580 = torch.prim.ListConstruct %358, %int1_5643 : (!torch.int, !torch.int) -> !torch.list - %int4_5644 = torch.constant.int 4 - %int0_5645 = torch.constant.int 0 - %cpu_5646 = torch.constant.device "cpu" - %false_5647 = torch.constant.bool false - %4581 = torch.aten.empty_strided %4579, %4580, %int4_5644, %int0_5645, %cpu_5646, %false_5647 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4581, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int20_5648 = torch.constant.int 20 - %4582 = torch.aten.fill.Scalar %4581, %int20_5648 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4582, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_5649 = torch.constant.int 32 - %4583 = torch.aten.mul.Scalar %arg3, %int32_5649 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4583, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_5650 = torch.constant.int 1 - %4584 = torch.aten.add.Tensor %4583, %4582, %int1_5650 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4584, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_5651 = torch.constant.int 4 - %4585 = torch.aten.mul.int %int4_5651, %358 : !torch.int, !torch.int -> !torch.int - %4586 = torch.prim.ListConstruct %4585 : (!torch.int) -> !torch.list - %4587 = torch.aten.view %4584, %4586 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %4587, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_5652 = torch.constant.int 32 - %int2_5653 = torch.constant.int 2 - %int32_5654 = torch.constant.int 32 - %int8_5655 = torch.constant.int 8 - %int128_5656 = torch.constant.int 128 - %4588 = torch.prim.ListConstruct %437, %int32_5652, %int2_5653, %int32_5654, %int8_5655, %int128_5656 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4589 = torch.aten.view %4578, %4588 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4589, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : 
!torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5657 = torch.constant.int 32 - %4590 = torch.aten.mul.int %437, %int32_5657 : !torch.int, !torch.int -> !torch.int - %int2_5658 = torch.constant.int 2 - %int32_5659 = torch.constant.int 32 - %int8_5660 = torch.constant.int 8 - %int128_5661 = torch.constant.int 128 - %4591 = torch.prim.ListConstruct %4590, %int2_5658, %int32_5659, %int8_5660, %int128_5661 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4592 = torch.aten.view %4589, %4591 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %4592, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_5662 = torch.constant.int 0 - %4593 = torch.aten.index_select %4592, %int0_5662, %4587 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %4593, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_5663 = torch.constant.int 4 - %int2_5664 = torch.constant.int 2 - %int32_5665 = torch.constant.int 32 - %int8_5666 = torch.constant.int 8 - %int128_5667 = torch.constant.int 128 - %4594 = torch.prim.ListConstruct %int4_5663, %358, %int2_5664, %int32_5665, %int8_5666, %int128_5667 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4595 = torch.aten.view %4593, %4594 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4595, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_5668 = torch.constant.int 0 - %int0_5669 = torch.constant.int 0 - %int9223372036854775807_5670 = torch.constant.int 9223372036854775807 - %int1_5671 = torch.constant.int 1 - %4596 = torch.aten.slice.Tensor %4595, %int0_5668, %int0_5669, %int9223372036854775807_5670, %int1_5671 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4596, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_5672 = torch.constant.int 1 - %int0_5673 = torch.constant.int 0 - %int9223372036854775807_5674 = torch.constant.int 9223372036854775807 - %int1_5675 = torch.constant.int 1 - %4597 = torch.aten.slice.Tensor %4596, %int1_5672, %int0_5673, %int9223372036854775807_5674, %int1_5675 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4597, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_5676 = torch.constant.int 2 + %5334 = torch.aten.add.Tensor %5333, %5288, %int1_5628 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_5629 = torch.constant.int 5 + %5335 = torch.prims.convert_element_type %5260, %int5_5629 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %5336 = torch.prim.ListConstruct %5334 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_5630 = torch.constant.bool false + %5337 = torch.aten.index_put %5320, %5336, %5335, %false_5630 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5337, [%454], affine_map<()[s0] -> (s0 * 
16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_5631 = torch.constant.int 32 + %int2_5632 = torch.constant.int 2 + %int8_5633 = torch.constant.int 8 + %int32_5634 = torch.constant.int 32 + %int128_5635 = torch.constant.int 128 + %5338 = torch.prim.ListConstruct %456, %int32_5631, %int2_5632, %int8_5633, %int32_5634, %int128_5635 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5339 = torch.aten.view %5337, %5338 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5339, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_5636 = torch.constant.int 2097152 + %5340 = torch.prim.ListConstruct %456, %int2097152_5636 : (!torch.int, !torch.int) -> !torch.list + %5341 = torch.aten.view %5339, %5340 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5341, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_5637 = torch.constant.none + %5342 = torch.aten.clone %316, %none_5637 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5343 = torch.aten.detach %5342 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5344 = torch.aten.detach %5343 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5345 = torch.aten.detach %5344 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_5638 = torch.constant.none + %5346 = torch.aten.clone %317, %none_5638 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5347 = torch.aten.detach %5346 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5348 = torch.aten.detach %5347 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5349 = torch.aten.detach %5348 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_5639 = torch.constant.none + %5350 = torch.aten.clone %318, %none_5639 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5351 = torch.aten.detach %5350 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5352 = torch.aten.detach %5351 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5353 = torch.aten.detach %5352 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_5640 = torch.constant.int 32 + %int2_5641 = torch.constant.int 2 + %int8_5642 = torch.constant.int 8 + %int32_5643 = torch.constant.int 32 + %int128_5644 = torch.constant.int 128 + %5354 = torch.prim.ListConstruct %456, %int32_5640, %int2_5641, %int8_5642, %int32_5643, %int128_5644 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5355 = torch.aten.view %5341, %5354 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5355, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %5356 = torch_c.to_builtin_tensor %5355 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %5357 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_5645 = tensor.cast %5357 : tensor<4x?xi64> to tensor + %5358 = torch_c.to_builtin_tensor %5345 : !torch.vtensor<[],si64> -> tensor + %5359 = torch_c.to_builtin_tensor %5349 : !torch.vtensor<[],si64> -> tensor + %5360 = util.call 
@paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%5356, %cast_5645, %5358, %5359) : (tensor, tensor, tensor, tensor) -> tensor + %cast_5646 = tensor.cast %5360 : tensor to tensor<4x?x8x32x128xf16> + %5361 = torch_c.from_builtin_tensor %cast_5646 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %5361, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %5362 = torch_c.to_builtin_tensor %5355 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %5363 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_5647 = tensor.cast %5363 : tensor<4x?xi64> to tensor + %5364 = torch_c.to_builtin_tensor %5345 : !torch.vtensor<[],si64> -> tensor + %5365 = torch_c.to_builtin_tensor %5353 : !torch.vtensor<[],si64> -> tensor + %5366 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%5362, %cast_5647, %5364, %5365) : (tensor, tensor, tensor, tensor) -> tensor + %cast_5648 = tensor.cast %5366 : tensor to tensor<4x?x8x32x128xf16> + %5367 = torch_c.from_builtin_tensor %cast_5648 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %5367, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_5649 = torch.constant.int 2 + %int3_5650 = torch.constant.int 3 + %5368 = torch.aten.transpose.int %5361, %int2_5649, %int3_5650 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5368, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_5651 = torch.constant.int 0 + %5369 = torch.aten.clone %5368, %int0_5651 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5369, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_5652 = torch.constant.int 4 + %int8_5653 = torch.constant.int 8 + %int128_5654 = torch.constant.int 128 + %5370 = torch.prim.ListConstruct %int4_5652, %457, %int8_5653, %int128_5654 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5371 = torch.aten._unsafe_view %5369, %5370 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5371, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_5655 = torch.constant.int 2 + %int3_5656 = torch.constant.int 3 + %5372 = torch.aten.transpose.int %5367, %int2_5655, %int3_5656 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5372, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_5657 = torch.constant.int 0 + %5373 = torch.aten.clone %5372, %int0_5657 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5373, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_5658 = torch.constant.int 4 + %int8_5659 = torch.constant.int 8 + %int128_5660 = torch.constant.int 128 + %5374 = 
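Reading the cache back is no longer the old view/index_select/slice chain; both K and V now go through the `paged_attention_kv_cache_gather_*` microkernel (`%5360`, `%5366`), which takes the 6-D cache view, the page table (`%arg3`), a layer id (`%5345`, derived from `%316`) and a partition id (`%5349`/`%5353`, from `%317`/`%318`) and returns [4, pages, 8, 32, 128]. Judging by the call name and result shapes, it is functionally a batched page lookup; a PyTorch sketch of that assumed semantics (the layer value 22 is illustrative):

import torch

def gather_kv(cache6d, page_table, layer, part):
    # cache6d: [pages, LAYERS, 2, KV_HEADS, BLOCK, HEAD_DIM]
    # page_table: [batch, seq_pages] of page ids (i64).
    # Advanced indexing pulls each sequence's pages for one layer and one
    # K/V partition: result [batch, seq_pages, KV_HEADS, BLOCK, HEAD_DIM].
    return cache6d[page_table, layer, part]

cache6d = torch.zeros(4, 32, 2, 8, 32, 128, dtype=torch.float16)
page_table = torch.arange(4).repeat(4, 1)       # [batch=4, seq_pages=4]
k_pages = gather_kv(cache6d, page_table, layer=22, part=0)
v_pages = gather_kv(cache6d, page_table, layer=22, part=1)

The transpose/clone/_unsafe_view that follows (`%5368`-`%5375`) then folds (pages, block) into a [4, seq, 8, 128] sequence layout.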
torch.prim.ListConstruct %int4_5658, %457, %int8_5659, %int128_5660 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5375 = torch.aten._unsafe_view %5373, %5374 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5375, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_5661 = torch.constant.int -2 + %5376 = torch.aten.unsqueeze %5371, %int-2_5661 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5376, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_5662 = torch.constant.int 4 + %int8_5663 = torch.constant.int 8 + %int4_5664 = torch.constant.int 4 + %int128_5665 = torch.constant.int 128 + %5377 = torch.prim.ListConstruct %int4_5662, %457, %int8_5663, %int4_5664, %int128_5665 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_5666 = torch.constant.bool false + %5378 = torch.aten.expand %5376, %5377, %false_5666 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5378, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_5667 = torch.constant.int 0 + %5379 = torch.aten.clone %5378, %int0_5667 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5379, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_5668 = torch.constant.int 4 + %int32_5669 = torch.constant.int 32 + %int128_5670 = torch.constant.int 128 + %5380 = torch.prim.ListConstruct %int4_5668, %457, %int32_5669, %int128_5670 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5381 = torch.aten._unsafe_view %5379, %5380 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5381, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_5671 = torch.constant.int -2 + %5382 = torch.aten.unsqueeze %5375, %int-2_5671 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5382, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_5672 = torch.constant.int 4 + %int8_5673 = torch.constant.int 8 + %int4_5674 = torch.constant.int 4 + %int128_5675 = torch.constant.int 128 + %5383 = torch.prim.ListConstruct %int4_5672, %457, %int8_5673, %int4_5674, %int128_5675 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_5676 = torch.constant.bool false + %5384 = torch.aten.expand %5382, %5383, %false_5676 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5384, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int0_5677 = torch.constant.int 0 - %4598 = torch.aten.select.int %4597, %int2_5676, %int0_5677 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4598, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_5678 = torch.constant.int 32 - %4599 = torch.aten.mul.int %358, %int32_5678 : !torch.int, !torch.int -> !torch.int - %int2_5679 = torch.constant.int 2 - %int0_5680 = 
torch.constant.int 0 + %5385 = torch.aten.clone %5384, %int0_5677 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5385, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_5678 = torch.constant.int 4 + %int32_5679 = torch.constant.int 32 + %int128_5680 = torch.constant.int 128 + %5386 = torch.prim.ListConstruct %int4_5678, %457, %int32_5679, %int128_5680 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5387 = torch.aten._unsafe_view %5385, %5386 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5387, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int1_5681 = torch.constant.int 1 - %4600 = torch.aten.slice.Tensor %4598, %int2_5679, %int0_5680, %4599, %int1_5681 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4600, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_5682 = torch.constant.int 0 - %4601 = torch.aten.clone %4600, %int0_5682 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4601, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int2_5682 = torch.constant.int 2 + %5388 = torch.aten.transpose.int %5270, %int1_5681, %int2_5682 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_5683 = torch.constant.int 1 - %4602 = torch.aten.size.int %4597, %int1_5683 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_5684 = torch.constant.int 32 - %4603 = torch.aten.mul.int %4602, %int32_5684 : !torch.int, !torch.int -> !torch.int - %int4_5685 = torch.constant.int 4 - %int8_5686 = torch.constant.int 8 - %int128_5687 = torch.constant.int 128 - %4604 = torch.prim.ListConstruct %int4_5685, %4603, %int8_5686, %int128_5687 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4605 = torch.aten._unsafe_view %4601, %4604 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4605, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_5688 = torch.constant.int 0 - %int0_5689 = torch.constant.int 0 - %int9223372036854775807_5690 = torch.constant.int 9223372036854775807 - %int1_5691 = torch.constant.int 1 - %4606 = torch.aten.slice.Tensor %4605, %int0_5688, %int0_5689, %int9223372036854775807_5690, %int1_5691 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4606, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_5692 = torch.constant.int 0 - %int0_5693 = torch.constant.int 0 - %int9223372036854775807_5694 = torch.constant.int 9223372036854775807 - %int1_5695 = torch.constant.int 1 - %4607 = torch.aten.slice.Tensor %4595, %int0_5692, %int0_5693, %int9223372036854775807_5694, %int1_5695 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4607, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_5696 = torch.constant.int 1 - %int0_5697 = torch.constant.int 0 - 
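`%5376`-`%5387` are the usual grouped-query expansion: each of the 8 KV heads is broadcast to the 4 query heads of its group (unsqueeze, expand, then a view materialized by the clones), so K and V end up with 32 heads to match Q. A small equivalent sketch:

import torch

def expand_kv_heads(x, group=4):
    # x: [batch, seq, kv_heads, head_dim] -> [batch, seq, kv_heads*group, head_dim]
    # cf. %5376 unsqueeze, %5378 expand, %5381 _unsafe_view.
    b, t, h, d = x.shape
    return x.unsqueeze(3).expand(b, t, h, group, d).reshape(b, t, h * group, d)

x = torch.randn(4, 64, 8, 128, dtype=torch.float16)
assert torch.equal(expand_kv_heads(x), torch.repeat_interleave(x, 4, dim=2))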
%int9223372036854775807_5698 = torch.constant.int 9223372036854775807 - %int1_5699 = torch.constant.int 1 - %4608 = torch.aten.slice.Tensor %4607, %int1_5696, %int0_5697, %int9223372036854775807_5698, %int1_5699 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4608, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_5700 = torch.constant.int 2 + %int2_5684 = torch.constant.int 2 + %5389 = torch.aten.transpose.int %5381, %int1_5683, %int2_5684 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5389, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_5685 = torch.constant.int 1 + %int2_5686 = torch.constant.int 2 + %5390 = torch.aten.transpose.int %5387, %int1_5685, %int2_5686 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5390, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_5687 = torch.constant.float 0.000000e+00 + %false_5688 = torch.constant.bool false + %none_5689 = torch.constant.none + %5391:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5388, %5389, %5390, %float0.000000e00_5687, %false_5688, %470, %none_5689) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_5690 = torch.constant.int 1 + %int2_5691 = torch.constant.int 2 + %5392 = torch.aten.transpose.int %5391#0, %int1_5690, %int2_5691 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_5692 = torch.constant.int 4 + %int1_5693 = torch.constant.int 1 + %int4096_5694 = torch.constant.int 4096 + %5393 = torch.prim.ListConstruct %int4_5692, %int1_5693, %int4096_5694 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5394 = torch.aten.view %5392, %5393 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_5695 = torch.constant.int -2 + %int-1_5696 = torch.constant.int -1 + %5395 = torch.aten.transpose.int %319, %int-2_5695, %int-1_5696 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5697 = torch.constant.int 5 + %5396 = torch.prims.convert_element_type %5395, %int5_5697 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_5698 = torch.constant.int 4 + %int4096_5699 = torch.constant.int 4096 + %5397 = torch.prim.ListConstruct %int4_5698, %int4096_5699 : (!torch.int, !torch.int) -> !torch.list + %5398 = torch.aten.view %5394, %5397 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5399 = torch.aten.mm %5398, %5396 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_5700 = torch.constant.int 4 %int1_5701 = torch.constant.int 1 - %4609 = torch.aten.select.int %4608, %int2_5700, %int1_5701 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4609, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_5702 = torch.constant.int 2 - 
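With Q back in [4,32,1,128] (`%5388`) and the expanded K/V transposed to [4,32,seq,128] (`%5389`/`%5390`), attention is a single `_scaled_dot_product_flash_attention_for_cpu` call (`%5391`) with dropout 0.0, is_causal false, and the additive f16 mask `%470` of shape [4,1,1,seq]; the auxiliary f32 [4,32,1] result (the softmax statistics) appears unused here. `%5395`-`%5401` then apply the `attn_output` weight as x @ W.T via the transpose+mm pair. A sketch, in f32 for CPU portability (the IR runs this in f16, and the seq length 256 is illustrative):

import torch
import torch.nn.functional as F

q = torch.randn(4, 32, 1, 128)     # one decode step
k = torch.randn(4, 32, 256, 128)
v = torch.randn(4, 32, 256, 128)
mask = torch.zeros(4, 1, 1, 256)   # additive attention mask, cf. %470
w_o = torch.randn(4096, 4096)      # attn_output weight, cf. %319

attn = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # scale 1/sqrt(128)
out = attn.transpose(1, 2).reshape(4, 1, 4096) @ w_o.T          # cf. %5392..%5401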
%int0_5703 = torch.constant.int 0 - %int1_5704 = torch.constant.int 1 - %4610 = torch.aten.slice.Tensor %4609, %int2_5702, %int0_5703, %4599, %int1_5704 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4610, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_5705 = torch.constant.int 0 - %4611 = torch.aten.clone %4610, %int0_5705 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4611, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_5706 = torch.constant.int 1 - %4612 = torch.aten.size.int %4608, %int1_5706 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_5707 = torch.constant.int 32 - %4613 = torch.aten.mul.int %4612, %int32_5707 : !torch.int, !torch.int -> !torch.int - %int4_5708 = torch.constant.int 4 - %int8_5709 = torch.constant.int 8 - %int128_5710 = torch.constant.int 128 - %4614 = torch.prim.ListConstruct %int4_5708, %4613, %int8_5709, %int128_5710 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4615 = torch.aten._unsafe_view %4611, %4614 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4615, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_5711 = torch.constant.int 0 - %int0_5712 = torch.constant.int 0 - %int9223372036854775807_5713 = torch.constant.int 9223372036854775807 - %int1_5714 = torch.constant.int 1 - %4616 = torch.aten.slice.Tensor %4615, %int0_5711, %int0_5712, %int9223372036854775807_5713, %int1_5714 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4616, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_5715 = torch.constant.int -2 - %4617 = torch.aten.unsqueeze %4606, %int-2_5715 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4617, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_5716 = torch.constant.int 1 - %4618 = torch.aten.size.int %4605, %int1_5716 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_5717 = torch.constant.int 4 - %int8_5718 = torch.constant.int 8 - %int4_5719 = torch.constant.int 4 - %int128_5720 = torch.constant.int 128 - %4619 = torch.prim.ListConstruct %int4_5717, %4618, %int8_5718, %int4_5719, %int128_5720 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_5721 = torch.constant.bool false - %4620 = torch.aten.expand %4617, %4619, %false_5721 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4620, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_5722 = torch.constant.int 0 - %4621 = torch.aten.clone %4620, %int0_5722 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4621, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_5723 = torch.constant.int 4 - %int32_5724 = torch.constant.int 32 - %int128_5725 = torch.constant.int 128 - %4622 = torch.prim.ListConstruct %int4_5723, %4618, %int32_5724, 
%int128_5725 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4623 = torch.aten._unsafe_view %4621, %4622 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4623, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_5726 = torch.constant.int -2 - %4624 = torch.aten.unsqueeze %4616, %int-2_5726 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4624, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4096_5702 = torch.constant.int 4096 + %5400 = torch.prim.ListConstruct %int4_5700, %int1_5701, %int4096_5702 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5401 = torch.aten.view %5399, %5400 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_5703 = torch.constant.int 1 + %5402 = torch.aten.add.Tensor %5223, %5401, %int1_5703 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_5704 = torch.constant.int 6 + %5403 = torch.prims.convert_element_type %5402, %int6_5704 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_5705 = torch.constant.int 2 + %5404 = torch.aten.pow.Tensor_Scalar %5403, %int2_5705 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_5706 = torch.constant.int -1 + %5405 = torch.prim.ListConstruct %int-1_5706 : (!torch.int) -> !torch.list + %true_5707 = torch.constant.bool true + %none_5708 = torch.constant.none + %5406 = torch.aten.mean.dim %5404, %5405, %true_5707, %none_5708 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_5709 = torch.constant.float 9.9999997473787516E-6 + %int1_5710 = torch.constant.int 1 + %5407 = torch.aten.add.Scalar %5406, %float9.999990e-06_5709, %int1_5710 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %5408 = torch.aten.rsqrt %5407 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %5409 = torch.aten.mul.Tensor %5403, %5408 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_5711 = torch.constant.int 5 + %5410 = torch.prims.convert_element_type %5409, %int5_5711 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %5411 = torch.aten.mul.Tensor %320, %5410 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_5712 = torch.constant.int 5 + %5412 = torch.prims.convert_element_type %5411, %int5_5712 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_5713 = torch.constant.int -2 + %int-1_5714 = torch.constant.int -1 + %5413 = torch.aten.transpose.int %321, %int-2_5713, %int-1_5714 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5715 = torch.constant.int 5 + %5414 = torch.prims.convert_element_type %5413, %int5_5715 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_5716 = torch.constant.int 4 + %int4096_5717 = torch.constant.int 4096 + %5415 = torch.prim.ListConstruct %int4_5716, %int4096_5717 : (!torch.int, !torch.int) -> !torch.list + %5416 = torch.aten.view %5412, %5415 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5417 = 
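`%5403`-`%5412` are RMSNorm applied after the attention residual (`%5402`): upcast to f32, mean of squares over the hidden dim, rsqrt with eps 1e-5 (printed as 9.9999997473787516E-6, the f32 rounding), downcast, then scale by the f32 `ffn_norm` weight. As a sketch matching that op order:

import torch

def rms_norm(x, weight, eps=1e-5):
    # x: f16 [4, 1, 4096]; weight: f32 [4096]; cf. %5403..%5412.
    x32 = x.float()
    x32 = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (weight * x32.half()).half()  # f32*f16 promotes to f32, then back

x = torch.randn(4, 1, 4096, dtype=torch.float16)
w = torch.ones(4096, dtype=torch.float32)
h = rms_norm(x, w)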
torch.aten.mm %5416, %5414 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_5718 = torch.constant.int 4 + %int1_5719 = torch.constant.int 1 + %int14336_5720 = torch.constant.int 14336 + %5418 = torch.prim.ListConstruct %int4_5718, %int1_5719, %int14336_5720 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5419 = torch.aten.view %5417, %5418 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %5420 = torch.aten.silu %5419 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_5721 = torch.constant.int -2 + %int-1_5722 = torch.constant.int -1 + %5421 = torch.aten.transpose.int %322, %int-2_5721, %int-1_5722 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5723 = torch.constant.int 5 + %5422 = torch.prims.convert_element_type %5421, %int5_5723 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_5724 = torch.constant.int 4 + %int4096_5725 = torch.constant.int 4096 + %5423 = torch.prim.ListConstruct %int4_5724, %int4096_5725 : (!torch.int, !torch.int) -> !torch.list + %5424 = torch.aten.view %5412, %5423 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5425 = torch.aten.mm %5424, %5422 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_5726 = torch.constant.int 4 %int1_5727 = torch.constant.int 1 - %4625 = torch.aten.size.int %4615, %int1_5727 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_5728 = torch.constant.int 4 - %int8_5729 = torch.constant.int 8 - %int4_5730 = torch.constant.int 4 - %int128_5731 = torch.constant.int 128 - %4626 = torch.prim.ListConstruct %int4_5728, %4625, %int8_5729, %int4_5730, %int128_5731 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_5732 = torch.constant.bool false - %4627 = torch.aten.expand %4624, %4626, %false_5732 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4627, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_5733 = torch.constant.int 0 - %4628 = torch.aten.clone %4627, %int0_5733 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4628, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int14336_5728 = torch.constant.int 14336 + %5426 = torch.prim.ListConstruct %int4_5726, %int1_5727, %int14336_5728 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5427 = torch.aten.view %5425, %5426 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %5428 = torch.aten.mul.Tensor %5420, %5427 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_5729 = torch.constant.int -2 + %int-1_5730 = torch.constant.int -1 + %5429 = torch.aten.transpose.int %323, %int-2_5729, %int-1_5730 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_5731 = torch.constant.int 5 + %5430 = torch.prims.convert_element_type %5429, %int5_5731 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_5732 = torch.constant.int 4 + %int14336_5733 = torch.constant.int 14336 + %5431 = torch.prim.ListConstruct %int4_5732, %int14336_5733 : 
(!torch.int, !torch.int) -> !torch.list + %5432 = torch.aten.view %5428, %5431 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %5433 = torch.aten.mm %5432, %5430 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_5734 = torch.constant.int 4 - %int32_5735 = torch.constant.int 32 - %int128_5736 = torch.constant.int 128 - %4629 = torch.prim.ListConstruct %int4_5734, %4625, %int32_5735, %int128_5736 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4630 = torch.aten._unsafe_view %4628, %4629 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4630, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_5735 = torch.constant.int 1 + %int4096_5736 = torch.constant.int 4096 + %5434 = torch.prim.ListConstruct %int4_5734, %int1_5735, %int4096_5736 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5435 = torch.aten.view %5433, %5434 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_5737 = torch.constant.int 1 - %int2_5738 = torch.constant.int 2 - %4631 = torch.aten.transpose.int %4511, %int1_5737, %int2_5738 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_5739 = torch.constant.int 1 - %int2_5740 = torch.constant.int 2 - %4632 = torch.aten.transpose.int %4623, %int1_5739, %int2_5740 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4632, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_5741 = torch.constant.int 1 - %int2_5742 = torch.constant.int 2 - %4633 = torch.aten.transpose.int %4630, %int1_5741, %int2_5742 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4633, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_5743 = torch.constant.float 0.000000e+00 - %false_5744 = torch.constant.bool false - %none_5745 = torch.constant.none - %4634:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4631, %4632, %4633, %float0.000000e00_5743, %false_5744, %368, %none_5745) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_5746 = torch.constant.int 1 - %int2_5747 = torch.constant.int 2 - %4635 = torch.aten.transpose.int %4634#0, %int1_5746, %int2_5747 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_5748 = torch.constant.int 4 - %int1_5749 = torch.constant.int 1 - %int4096_5750 = torch.constant.int 4096 - %4636 = torch.prim.ListConstruct %int4_5748, %int1_5749, %int4096_5750 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4637 = torch.aten.view %4635, %4636 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_5751 = torch.constant.int -2 - %int-1_5752 = torch.constant.int -1 - %4638 = torch.aten.transpose.int %227, %int-2_5751, %int-1_5752 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_5753 = torch.constant.int 4 + %5436 = torch.aten.add.Tensor %5402, %5435, %int1_5737 : 
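`%5413`-`%5436` are the SwiGLU MLP and its residual: gate = silu(h @ W_gate.T), up = h @ W_up.T, out = (gate * up) @ W_down.T, added back onto the attention output; each linear is lowered as transpose + flatten-to-2-D + mm + reshape. Condensed (f32 here for CPU portability; the IR stays in f16):

import torch
import torch.nn.functional as F

def swiglu_ffn(h, w_gate, w_up, w_down):
    # h: [4, 1, 4096]; w_gate/w_up: [14336, 4096]; w_down: [4096, 14336]
    return (F.silu(h @ w_gate.T) * (h @ w_up.T)) @ w_down.T  # cf. %5417..%5433

x = torch.randn(4, 1, 4096)      # post-attention activations, cf. %5402
h = x                            # in the IR this is rms_norm(x, ffn_norm), %5412
w_gate = torch.randn(14336, 4096)   # %321
w_up = torch.randn(14336, 4096)     # %322
w_down = torch.randn(4096, 14336)   # %323
x = x + swiglu_ffn(h, w_gate, w_up, w_down)  # residual add, cf. %5436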
!torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_5738 = torch.constant.int 6 + %5437 = torch.prims.convert_element_type %5436, %int6_5738 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_5739 = torch.constant.int 2 + %5438 = torch.aten.pow.Tensor_Scalar %5437, %int2_5739 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_5740 = torch.constant.int -1 + %5439 = torch.prim.ListConstruct %int-1_5740 : (!torch.int) -> !torch.list + %true_5741 = torch.constant.bool true + %none_5742 = torch.constant.none + %5440 = torch.aten.mean.dim %5438, %5439, %true_5741, %none_5742 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_5743 = torch.constant.float 9.9999997473787516E-6 + %int1_5744 = torch.constant.int 1 + %5441 = torch.aten.add.Scalar %5440, %float9.999990e-06_5743, %int1_5744 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %5442 = torch.aten.rsqrt %5441 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %5443 = torch.aten.mul.Tensor %5437, %5442 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_5745 = torch.constant.int 5 + %5444 = torch.prims.convert_element_type %5443, %int5_5745 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %5445 = torch.aten.mul.Tensor %324, %5444 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_5746 = torch.constant.int 5 + %5446 = torch.prims.convert_element_type %5445, %int5_5746 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_5747 = torch.constant.int -2 + %int-1_5748 = torch.constant.int -1 + %5447 = torch.aten.transpose.int %325, %int-2_5747, %int-1_5748 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5749 = torch.constant.int 5 + %5448 = torch.prims.convert_element_type %5447, %int5_5749 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_5750 = torch.constant.int 4 + %int4096_5751 = torch.constant.int 4096 + %5449 = torch.prim.ListConstruct %int4_5750, %int4096_5751 : (!torch.int, !torch.int) -> !torch.list + %5450 = torch.aten.view %5446, %5449 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5451 = torch.aten.mm %5450, %5448 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_5752 = torch.constant.int 4 + %int1_5753 = torch.constant.int 1 %int4096_5754 = torch.constant.int 4096 - %4639 = torch.prim.ListConstruct %int4_5753, %int4096_5754 : (!torch.int, !torch.int) -> !torch.list - %4640 = torch.aten.view %4637, %4639 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4641 = torch.aten.mm %4640, %4638 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_5755 = torch.constant.int 4 - %int1_5756 = torch.constant.int 1 - %int4096_5757 = torch.constant.int 4096 - %4642 = torch.prim.ListConstruct %int4_5755, %int1_5756, %int4096_5757 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4643 = torch.aten.view %4641, %4642 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_5758 = torch.constant.int 1 - %4644 = torch.aten.add.Tensor %4471, 
%4643, %int1_5758 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_5759 = torch.constant.int 6 - %4645 = torch.prims.convert_element_type %4644, %int6_5759 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_5760 = torch.constant.int 2 - %4646 = torch.aten.pow.Tensor_Scalar %4645, %int2_5760 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_5761 = torch.constant.int -1 - %4647 = torch.prim.ListConstruct %int-1_5761 : (!torch.int) -> !torch.list - %true_5762 = torch.constant.bool true - %none_5763 = torch.constant.none - %4648 = torch.aten.mean.dim %4646, %4647, %true_5762, %none_5763 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_5764 = torch.constant.float 9.9999997473787516E-6 - %int1_5765 = torch.constant.int 1 - %4649 = torch.aten.add.Scalar %4648, %float9.999990e-06_5764, %int1_5765 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %4650 = torch.aten.rsqrt %4649 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %4651 = torch.aten.mul.Tensor %4645, %4650 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_5766 = torch.constant.int 5 - %4652 = torch.prims.convert_element_type %4651, %int5_5766 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %4653 = torch.aten.mul.Tensor %228, %4652 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_5767 = torch.constant.int 5 - %4654 = torch.prims.convert_element_type %4653, %int5_5767 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_5768 = torch.constant.int -2 - %int-1_5769 = torch.constant.int -1 - %4655 = torch.aten.transpose.int %229, %int-2_5768, %int-1_5769 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_5770 = torch.constant.int 4 - %int4096_5771 = torch.constant.int 4096 - %4656 = torch.prim.ListConstruct %int4_5770, %int4096_5771 : (!torch.int, !torch.int) -> !torch.list - %4657 = torch.aten.view %4654, %4656 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4658 = torch.aten.mm %4657, %4655 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_5772 = torch.constant.int 4 - %int1_5773 = torch.constant.int 1 - %int14336_5774 = torch.constant.int 14336 - %4659 = torch.prim.ListConstruct %int4_5772, %int1_5773, %int14336_5774 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4660 = torch.aten.view %4658, %4659 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %4661 = torch.aten.silu %4660 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_5775 = torch.constant.int -2 - %int-1_5776 = torch.constant.int -1 - %4662 = torch.aten.transpose.int %230, %int-2_5775, %int-1_5776 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_5777 = torch.constant.int 4 - %int4096_5778 = torch.constant.int 4096 - %4663 = torch.prim.ListConstruct %int4_5777, %int4096_5778 : (!torch.int, !torch.int) -> !torch.list - %4664 = torch.aten.view %4654, %4663 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4665 = torch.aten.mm %4664, %4662 : 
!torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %5452 = torch.prim.ListConstruct %int4_5752, %int1_5753, %int4096_5754 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5453 = torch.aten.view %5451, %5452 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_5755 = torch.constant.int -2 + %int-1_5756 = torch.constant.int -1 + %5454 = torch.aten.transpose.int %326, %int-2_5755, %int-1_5756 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5757 = torch.constant.int 5 + %5455 = torch.prims.convert_element_type %5454, %int5_5757 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_5758 = torch.constant.int 4 + %int4096_5759 = torch.constant.int 4096 + %5456 = torch.prim.ListConstruct %int4_5758, %int4096_5759 : (!torch.int, !torch.int) -> !torch.list + %5457 = torch.aten.view %5446, %5456 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5458 = torch.aten.mm %5457, %5455 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_5760 = torch.constant.int 4 + %int1_5761 = torch.constant.int 1 + %int1024_5762 = torch.constant.int 1024 + %5459 = torch.prim.ListConstruct %int4_5760, %int1_5761, %int1024_5762 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5460 = torch.aten.view %5458, %5459 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_5763 = torch.constant.int -2 + %int-1_5764 = torch.constant.int -1 + %5461 = torch.aten.transpose.int %327, %int-2_5763, %int-1_5764 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_5765 = torch.constant.int 5 + %5462 = torch.prims.convert_element_type %5461, %int5_5765 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_5766 = torch.constant.int 4 + %int4096_5767 = torch.constant.int 4096 + %5463 = torch.prim.ListConstruct %int4_5766, %int4096_5767 : (!torch.int, !torch.int) -> !torch.list + %5464 = torch.aten.view %5446, %5463 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5465 = torch.aten.mm %5464, %5462 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_5768 = torch.constant.int 4 + %int1_5769 = torch.constant.int 1 + %int1024_5770 = torch.constant.int 1024 + %5466 = torch.prim.ListConstruct %int4_5768, %int1_5769, %int1024_5770 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5467 = torch.aten.view %5465, %5466 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_5771 = torch.constant.int 4 + %int1_5772 = torch.constant.int 1 + %int32_5773 = torch.constant.int 32 + %int128_5774 = torch.constant.int 128 + %5468 = torch.prim.ListConstruct %int4_5771, %int1_5772, %int32_5773, %int128_5774 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5469 = torch.aten.view %5453, %5468 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_5775 = torch.constant.int 4 + %int1_5776 = torch.constant.int 1 + %int8_5777 = torch.constant.int 8 + %int128_5778 = torch.constant.int 128 + %5470 = torch.prim.ListConstruct %int4_5775, %int1_5776, %int8_5777, %int128_5778 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5471 = torch.aten.view %5460, %5470 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> 
!torch.vtensor<[4,1,8,128],f16> %int4_5779 = torch.constant.int 4 %int1_5780 = torch.constant.int 1 - %int14336_5781 = torch.constant.int 14336 - %4666 = torch.prim.ListConstruct %int4_5779, %int1_5780, %int14336_5781 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4667 = torch.aten.view %4665, %4666 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %4668 = torch.aten.mul.Tensor %4661, %4667 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_5782 = torch.constant.int -2 - %int-1_5783 = torch.constant.int -1 - %4669 = torch.aten.transpose.int %231, %int-2_5782, %int-1_5783 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_5784 = torch.constant.int 4 - %int14336_5785 = torch.constant.int 14336 - %4670 = torch.prim.ListConstruct %int4_5784, %int14336_5785 : (!torch.int, !torch.int) -> !torch.list - %4671 = torch.aten.view %4668, %4670 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %4672 = torch.aten.mm %4671, %4669 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_5786 = torch.constant.int 4 - %int1_5787 = torch.constant.int 1 - %int4096_5788 = torch.constant.int 4096 - %4673 = torch.prim.ListConstruct %int4_5786, %int1_5787, %int4096_5788 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4674 = torch.aten.view %4672, %4673 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_5789 = torch.constant.int 1 - %4675 = torch.aten.add.Tensor %4644, %4674, %int1_5789 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_5790 = torch.constant.int 6 - %4676 = torch.prims.convert_element_type %4675, %int6_5790 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_5791 = torch.constant.int 2 - %4677 = torch.aten.pow.Tensor_Scalar %4676, %int2_5791 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_5792 = torch.constant.int -1 - %4678 = torch.prim.ListConstruct %int-1_5792 : (!torch.int) -> !torch.list - %true_5793 = torch.constant.bool true - %none_5794 = torch.constant.none - %4679 = torch.aten.mean.dim %4677, %4678, %true_5793, %none_5794 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_5795 = torch.constant.float 9.9999997473787516E-6 - %int1_5796 = torch.constant.int 1 - %4680 = torch.aten.add.Scalar %4679, %float9.999990e-06_5795, %int1_5796 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %4681 = torch.aten.rsqrt %4680 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %4682 = torch.aten.mul.Tensor %4676, %4681 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_5797 = torch.constant.int 5 - %4683 = torch.prims.convert_element_type %4682, %int5_5797 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %4684 = torch.aten.mul.Tensor %232, %4683 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_5798 = torch.constant.int 5 - %4685 = torch.prims.convert_element_type %4684, %int5_5798 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_5799 = torch.constant.int -2 - %int-1_5800 = torch.constant.int -1 
- %4686 = torch.aten.transpose.int %233, %int-2_5799, %int-1_5800 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_5801 = torch.constant.int 4 - %int4096_5802 = torch.constant.int 4096 - %4687 = torch.prim.ListConstruct %int4_5801, %int4096_5802 : (!torch.int, !torch.int) -> !torch.list - %4688 = torch.aten.view %4685, %4687 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4689 = torch.aten.mm %4688, %4686 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_5803 = torch.constant.int 4 - %int1_5804 = torch.constant.int 1 - %int4096_5805 = torch.constant.int 4096 - %4690 = torch.prim.ListConstruct %int4_5803, %int1_5804, %int4096_5805 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4691 = torch.aten.view %4689, %4690 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_5806 = torch.constant.int -2 + %int8_5781 = torch.constant.int 8 + %int128_5782 = torch.constant.int 128 + %5472 = torch.prim.ListConstruct %int4_5779, %int1_5780, %int8_5781, %int128_5782 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5473 = torch.aten.view %5467, %5472 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_5783 = torch.constant.int 1 + %int2_5784 = torch.constant.int 2 + %5474 = torch.aten.transpose.int %5469, %int1_5783, %int2_5784 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %5475 = torch.aten.mul.Tensor %5474, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_5785 = torch.constant.int 3 + %int0_5786 = torch.constant.int 0 + %int64_5787 = torch.constant.int 64 + %int1_5788 = torch.constant.int 1 + %5476 = torch.aten.slice.Tensor %5474, %int3_5785, %int0_5786, %int64_5787, %int1_5788 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_5789 = torch.constant.int 3 + %int64_5790 = torch.constant.int 64 + %int9223372036854775807_5791 = torch.constant.int 9223372036854775807 + %int1_5792 = torch.constant.int 1 + %5477 = torch.aten.slice.Tensor %5474, %int3_5789, %int64_5790, %int9223372036854775807_5791, %int1_5792 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %5478 = torch.aten.neg %5477 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %5479 = torch.prim.ListConstruct %5478, %5476 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_5793 = torch.constant.int -1 + %5480 = torch.aten.cat %5479, %int-1_5793 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %5481 = torch.aten.mul.Tensor %5480, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_5794 = torch.constant.int 1 + %5482 = torch.aten.add.Tensor %5475, %5481, %int1_5794 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_5795 = torch.constant.int 1 + %int2_5796 = torch.constant.int 2 + %5483 = torch.aten.transpose.int %5482, %int1_5795, %int2_5796 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_5797 = torch.constant.int 1 + %int2_5798 = torch.constant.int 2 + %5484 = torch.aten.transpose.int %5471, 
%int1_5797, %int2_5798 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %5485 = torch.aten.mul.Tensor %5484, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_5799 = torch.constant.int 3 + %int0_5800 = torch.constant.int 0 + %int64_5801 = torch.constant.int 64 + %int1_5802 = torch.constant.int 1 + %5486 = torch.aten.slice.Tensor %5484, %int3_5799, %int0_5800, %int64_5801, %int1_5802 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_5803 = torch.constant.int 3 + %int64_5804 = torch.constant.int 64 + %int9223372036854775807_5805 = torch.constant.int 9223372036854775807 + %int1_5806 = torch.constant.int 1 + %5487 = torch.aten.slice.Tensor %5484, %int3_5803, %int64_5804, %int9223372036854775807_5805, %int1_5806 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %5488 = torch.aten.neg %5487 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %5489 = torch.prim.ListConstruct %5488, %5486 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list %int-1_5807 = torch.constant.int -1 - %4692 = torch.aten.transpose.int %234, %int-2_5806, %int-1_5807 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_5808 = torch.constant.int 4 - %int4096_5809 = torch.constant.int 4096 - %4693 = torch.prim.ListConstruct %int4_5808, %int4096_5809 : (!torch.int, !torch.int) -> !torch.list - %4694 = torch.aten.view %4685, %4693 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4695 = torch.aten.mm %4694, %4692 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_5810 = torch.constant.int 4 - %int1_5811 = torch.constant.int 1 - %int1024_5812 = torch.constant.int 1024 - %4696 = torch.prim.ListConstruct %int4_5810, %int1_5811, %int1024_5812 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4697 = torch.aten.view %4695, %4696 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_5813 = torch.constant.int -2 - %int-1_5814 = torch.constant.int -1 - %4698 = torch.aten.transpose.int %235, %int-2_5813, %int-1_5814 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %5490 = torch.aten.cat %5489, %int-1_5807 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %5491 = torch.aten.mul.Tensor %5490, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_5808 = torch.constant.int 1 + %5492 = torch.aten.add.Tensor %5485, %5491, %int1_5808 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_5809 = torch.constant.int 1 + %int2_5810 = torch.constant.int 2 + %5493 = torch.aten.transpose.int %5492, %int1_5809, %int2_5810 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_5811 = torch.constant.int 32 + %5494 = torch.aten.floor_divide.Scalar %arg2, %int32_5811 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_5812 = torch.constant.int 1 + %5495 = torch.aten.unsqueeze %5494, %int1_5812 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_5813 = torch.constant.int 1 + %false_5814 = torch.constant.bool false + 
%5496 = torch.aten.gather %arg3, %int1_5813, %5495, %false_5814 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> %int4_5815 = torch.constant.int 4 - %int4096_5816 = torch.constant.int 4096 - %4699 = torch.prim.ListConstruct %int4_5815, %int4096_5816 : (!torch.int, !torch.int) -> !torch.list - %4700 = torch.aten.view %4685, %4699 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4701 = torch.aten.mm %4700, %4698 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_5817 = torch.constant.int 4 - %int1_5818 = torch.constant.int 1 - %int1024_5819 = torch.constant.int 1024 - %4702 = torch.prim.ListConstruct %int4_5817, %int1_5818, %int1024_5819 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4703 = torch.aten.view %4701, %4702 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_5820 = torch.constant.int 4 + %int1_5816 = torch.constant.int 1 + %int1_5817 = torch.constant.int 1 + %5497 = torch.prim.ListConstruct %int4_5815, %int1_5816, %int1_5817 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5498 = torch.aten.view %5496, %5497 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_5818 = torch.constant.int 32 + %5499 = torch.aten.remainder.Scalar %arg2, %int32_5818 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_5819 = torch.constant.int 4 + %int1_5820 = torch.constant.int 1 %int1_5821 = torch.constant.int 1 - %int32_5822 = torch.constant.int 32 - %int128_5823 = torch.constant.int 128 - %4704 = torch.prim.ListConstruct %int4_5820, %int1_5821, %int32_5822, %int128_5823 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4705 = torch.aten.view %4691, %4704 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_5824 = torch.constant.int 4 - %int1_5825 = torch.constant.int 1 - %int8_5826 = torch.constant.int 8 - %int128_5827 = torch.constant.int 128 - %4706 = torch.prim.ListConstruct %int4_5824, %int1_5825, %int8_5826, %int128_5827 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4707 = torch.aten.view %4697, %4706 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_5828 = torch.constant.int 4 - %int1_5829 = torch.constant.int 1 - %int8_5830 = torch.constant.int 8 - %int128_5831 = torch.constant.int 128 - %4708 = torch.prim.ListConstruct %int4_5828, %int1_5829, %int8_5830, %int128_5831 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4709 = torch.aten.view %4703, %4708 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_5832 = torch.constant.int 6 - %4710 = torch.prims.convert_element_type %4705, %int6_5832 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %4711 = torch_c.to_builtin_tensor %4710 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %4712 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %4713 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%4711, %4712) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %4714 = torch_c.from_builtin_tensor %4713 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_5833 = torch.constant.int 5 - %4715 = torch.prims.convert_element_type %4714, %int5_5833 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> 
!torch.vtensor<[4,1,32,128],f16> - %int6_5834 = torch.constant.int 6 - %4716 = torch.prims.convert_element_type %4707, %int6_5834 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %4717 = torch_c.to_builtin_tensor %4716 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %4718 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %4719 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%4717, %4718) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %4720 = torch_c.from_builtin_tensor %4719 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_5835 = torch.constant.int 5 - %4721 = torch.prims.convert_element_type %4720, %int5_5835 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_5836 = torch.constant.int 32 - %4722 = torch.aten.floor_divide.Scalar %arg2, %int32_5836 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %5500 = torch.prim.ListConstruct %int4_5819, %int1_5820, %int1_5821 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5501 = torch.aten.view %5499, %5500 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_5822 = torch.constant.int 8 + %none_5823 = torch.constant.none + %none_5824 = torch.constant.none + %cpu_5825 = torch.constant.device "cpu" + %false_5826 = torch.constant.bool false + %5502 = torch.aten.arange %int8_5822, %none_5823, %none_5824, %cpu_5825, %false_5826 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_5827 = torch.constant.int 1 + %int1_5828 = torch.constant.int 1 + %int8_5829 = torch.constant.int 8 + %5503 = torch.prim.ListConstruct %int1_5827, %int1_5828, %int8_5829 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5504 = torch.aten.view %5502, %5503 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_5830 = torch.constant.none + %5505 = torch.aten.clone %328, %none_5830 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5506 = torch.aten.detach %5505 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5507 = torch.aten.detach %5506 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5508 = torch.aten.detach %5507 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_5831 = torch.constant.int 1 + %int1_5832 = torch.constant.int 1 + %int1_5833 = torch.constant.int 1 + %5509 = torch.prim.ListConstruct %int1_5831, %int1_5832, %int1_5833 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5510 = torch.aten.view %5508, %5509 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_5834 = torch.constant.int 32 + %5511 = torch.aten.mul.Scalar %5498, %int32_5834 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int23 = torch.constant.int 23 + %int1_5835 = torch.constant.int 1 + %5512 = torch.aten.add.Scalar %5511, %int23, %int1_5835 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_5836 = torch.constant.int 2 + %5513 = torch.aten.mul.Scalar %5512, %int2_5836 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_5837 = torch.constant.int 1 - %4723 = torch.aten.unsqueeze %4722, %int1_5837 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5838 = torch.constant.int 1 - %false_5839 = torch.constant.bool false - %4724 = torch.aten.gather %arg3, %int1_5838, %4723, %false_5839 : 
!torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %5514 = torch.aten.add.Tensor %5513, %5510, %int1_5837 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_5838 = torch.constant.int 8 + %5515 = torch.aten.mul.Scalar %5514, %int8_5838 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_5839 = torch.constant.int 1 + %5516 = torch.aten.add.Tensor %5515, %5504, %int1_5839 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int32_5840 = torch.constant.int 32 - %4725 = torch.aten.remainder.Scalar %arg2, %int32_5840 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %5517 = torch.aten.mul.Scalar %5516, %int32_5840 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_5841 = torch.constant.int 1 - %4726 = torch.aten.unsqueeze %4725, %int1_5841 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_5842 = torch.constant.none - %4727 = torch.aten.clone %236, %none_5842 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_5843 = torch.constant.int 0 - %4728 = torch.aten.unsqueeze %4727, %int0_5843 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_5844 = torch.constant.int 4 - %int1_5845 = torch.constant.int 1 - %4729 = torch.prim.ListConstruct %int4_5844, %int1_5845 : (!torch.int, !torch.int) -> !torch.list - %int1_5846 = torch.constant.int 1 - %int1_5847 = torch.constant.int 1 - %4730 = torch.prim.ListConstruct %int1_5846, %int1_5847 : (!torch.int, !torch.int) -> !torch.list - %int4_5848 = torch.constant.int 4 - %int0_5849 = torch.constant.int 0 - %cpu_5850 = torch.constant.device "cpu" - %false_5851 = torch.constant.bool false - %4731 = torch.aten.empty_strided %4729, %4730, %int4_5848, %int0_5849, %cpu_5850, %false_5851 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int21 = torch.constant.int 21 - %4732 = torch.aten.fill.Scalar %4731, %int21 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_5852 = torch.constant.int 4 - %int1_5853 = torch.constant.int 1 - %4733 = torch.prim.ListConstruct %int4_5852, %int1_5853 : (!torch.int, !torch.int) -> !torch.list - %4734 = torch.aten.repeat %4728, %4733 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_5854 = torch.constant.int 32 - %4735 = torch.aten.mul.Scalar %4724, %int32_5854 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5855 = torch.constant.int 1 - %4736 = torch.aten.add.Tensor %4735, %4732, %int1_5855 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_5856 = torch.constant.int 2 - %4737 = torch.aten.mul.Scalar %4736, %int2_5856 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5857 = torch.constant.int 1 - %4738 = torch.aten.add.Tensor %4737, %4734, %int1_5857 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_5858 = torch.constant.int 32 - %4739 = torch.aten.mul.Scalar %4738, %int32_5858 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5859 = torch.constant.int 1 - %4740 = torch.aten.add.Tensor %4739, %4726, %int1_5859 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> 
!torch.vtensor<[4,1],si64> - %int32_5860 = torch.constant.int 32 - %int2_5861 = torch.constant.int 2 - %int32_5862 = torch.constant.int 32 - %int8_5863 = torch.constant.int 8 - %int128_5864 = torch.constant.int 128 - %4741 = torch.prim.ListConstruct %437, %int32_5860, %int2_5861, %int32_5862, %int8_5863, %int128_5864 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4742 = torch.aten.view %4578, %4741 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4742, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5865 = torch.constant.int 32 - %4743 = torch.aten.mul.int %437, %int32_5865 : !torch.int, !torch.int -> !torch.int - %int2_5866 = torch.constant.int 2 - %4744 = torch.aten.mul.int %4743, %int2_5866 : !torch.int, !torch.int -> !torch.int - %int32_5867 = torch.constant.int 32 - %4745 = torch.aten.mul.int %4744, %int32_5867 : !torch.int, !torch.int -> !torch.int - %int8_5868 = torch.constant.int 8 - %int128_5869 = torch.constant.int 128 - %4746 = torch.prim.ListConstruct %4745, %int8_5868, %int128_5869 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4747 = torch.aten.view %4742, %4746 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4747, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %4748 = torch.prim.ListConstruct %4740 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_5870 = torch.constant.bool false - %4749 = torch.aten.index_put %4747, %4748, %4721, %false_5870 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4749, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_5871 = torch.constant.int 32 - %int2_5872 = torch.constant.int 2 + %5518 = torch.aten.add.Tensor %5517, %5501, %int1_5841 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_5842 = torch.constant.int 5 + %5519 = torch.prims.convert_element_type %5493, %int5_5842 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_5843 = torch.constant.int 32 + %int2_5844 = torch.constant.int 2 + %int8_5845 = torch.constant.int 8 + %int32_5846 = torch.constant.int 32 + %int128_5847 = torch.constant.int 128 + %5520 = torch.prim.ListConstruct %456, %int32_5843, %int2_5844, %int8_5845, %int32_5846, %int128_5847 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5521 = torch.aten.view %5341, %5520 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5521, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_5848 = torch.constant.int 128 + %5522 = torch.prim.ListConstruct %596, %int128_5848 : (!torch.int, !torch.int) -> !torch.list + %5523 = torch.aten.view %5521, %5522 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5523, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %5524 = torch.prim.ListConstruct %5518 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_5849 = torch.constant.bool false + %5525 = torch.aten.index_put %5523, %5524, %5519, %false_5849 : 
!torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5525, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_5850 = torch.constant.int 32 + %int2_5851 = torch.constant.int 2 + %int8_5852 = torch.constant.int 8 + %int32_5853 = torch.constant.int 32 + %int128_5854 = torch.constant.int 128 + %5526 = torch.prim.ListConstruct %456, %int32_5850, %int2_5851, %int8_5852, %int32_5853, %int128_5854 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5527 = torch.aten.view %5525, %5526 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5527, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_5855 = torch.constant.int 2097152 + %5528 = torch.prim.ListConstruct %456, %int2097152_5855 : (!torch.int, !torch.int) -> !torch.list + %5529 = torch.aten.view %5527, %5528 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5529, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_5856 = torch.constant.int 32 + %int2_5857 = torch.constant.int 2 + %int8_5858 = torch.constant.int 8 + %int32_5859 = torch.constant.int 32 + %int128_5860 = torch.constant.int 128 + %5530 = torch.prim.ListConstruct %456, %int32_5856, %int2_5857, %int8_5858, %int32_5859, %int128_5860 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5531 = torch.aten.view %5529, %5530 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5531, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_5861 = torch.constant.int 128 + %5532 = torch.prim.ListConstruct %596, %int128_5861 : (!torch.int, !torch.int) -> !torch.list + %5533 = torch.aten.view %5531, %5532 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5533, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_5862 = torch.constant.none + %5534 = torch.aten.clone %329, %none_5862 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5535 = torch.aten.detach %5534 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5536 = torch.aten.detach %5535 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5537 = torch.aten.detach %5536 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_5863 = torch.constant.int 1 + %int1_5864 = torch.constant.int 1 + %int1_5865 = torch.constant.int 1 + %5538 = torch.prim.ListConstruct %int1_5863, %int1_5864, %int1_5865 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5539 = torch.aten.view %5537, %5538 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_5866 = torch.constant.int 32 + %5540 = torch.aten.mul.Scalar %5498, %int32_5866 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int23_5867 = torch.constant.int 23 + %int1_5868 = torch.constant.int 1 + %5541 = torch.aten.add.Scalar %5540, %int23_5867, %int1_5868 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_5869 = torch.constant.int 2 + %5542 = torch.aten.mul.Scalar %5541, %int2_5869 : !torch.vtensor<[4,1,1],si64>, !torch.int -> 
!torch.vtensor<[4,1,1],si64> + %int1_5870 = torch.constant.int 1 + %5543 = torch.aten.add.Tensor %5542, %5539, %int1_5870 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_5871 = torch.constant.int 8 + %5544 = torch.aten.mul.Scalar %5543, %int8_5871 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_5872 = torch.constant.int 1 + %5545 = torch.aten.add.Tensor %5544, %5504, %int1_5872 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int32_5873 = torch.constant.int 32 - %int8_5874 = torch.constant.int 8 - %int128_5875 = torch.constant.int 128 - %4750 = torch.prim.ListConstruct %437, %int32_5871, %int2_5872, %int32_5873, %int8_5874, %int128_5875 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4751 = torch.aten.view %4749, %4750 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4751, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5876 = torch.constant.int 2097152 - %4752 = torch.prim.ListConstruct %437, %int2097152_5876 : (!torch.int, !torch.int) -> !torch.list - %4753 = torch.aten.view %4751, %4752 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4753, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %5546 = torch.aten.mul.Scalar %5545, %int32_5873 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_5874 = torch.constant.int 1 + %5547 = torch.aten.add.Tensor %5546, %5501, %int1_5874 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_5875 = torch.constant.int 5 + %5548 = torch.prims.convert_element_type %5473, %int5_5875 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %5549 = torch.prim.ListConstruct %5547 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_5876 = torch.constant.bool false + %5550 = torch.aten.index_put %5533, %5549, %5548, %false_5876 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5550, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> %int32_5877 = torch.constant.int 32 %int2_5878 = torch.constant.int 2 - %int32_5879 = torch.constant.int 32 - %int8_5880 = torch.constant.int 8 + %int8_5879 = torch.constant.int 8 + %int32_5880 = torch.constant.int 32 %int128_5881 = torch.constant.int 128 - %4754 = torch.prim.ListConstruct %437, %int32_5877, %int2_5878, %int32_5879, %int8_5880, %int128_5881 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4755 = torch.aten.view %4753, %4754 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4755, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_5882 = torch.constant.int 8 - %int128_5883 = torch.constant.int 128 - %4756 = torch.prim.ListConstruct %4745, %int8_5882, %int128_5883 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4757 = torch.aten.view %4755, %4756 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4757, [%357], affine_map<()[s0] 
-> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_5884 = torch.constant.int 32 - %4758 = torch.aten.floor_divide.Scalar %arg2, %int32_5884 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_5885 = torch.constant.int 1 - %4759 = torch.aten.unsqueeze %4758, %int1_5885 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5886 = torch.constant.int 1 - %false_5887 = torch.constant.bool false - %4760 = torch.aten.gather %arg3, %int1_5886, %4759, %false_5887 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_5888 = torch.constant.int 32 - %4761 = torch.aten.remainder.Scalar %arg2, %int32_5888 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_5889 = torch.constant.int 1 - %4762 = torch.aten.unsqueeze %4761, %int1_5889 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_5890 = torch.constant.none - %4763 = torch.aten.clone %237, %none_5890 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_5891 = torch.constant.int 0 - %4764 = torch.aten.unsqueeze %4763, %int0_5891 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_5892 = torch.constant.int 4 - %int1_5893 = torch.constant.int 1 - %4765 = torch.prim.ListConstruct %int4_5892, %int1_5893 : (!torch.int, !torch.int) -> !torch.list - %int1_5894 = torch.constant.int 1 - %int1_5895 = torch.constant.int 1 - %4766 = torch.prim.ListConstruct %int1_5894, %int1_5895 : (!torch.int, !torch.int) -> !torch.list - %int4_5896 = torch.constant.int 4 + %5551 = torch.prim.ListConstruct %456, %int32_5877, %int2_5878, %int8_5879, %int32_5880, %int128_5881 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5552 = torch.aten.view %5550, %5551 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5552, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_5882 = torch.constant.int 2097152 + %5553 = torch.prim.ListConstruct %456, %int2097152_5882 : (!torch.int, !torch.int) -> !torch.list + %5554 = torch.aten.view %5552, %5553 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5554, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_5883 = torch.constant.none + %5555 = torch.aten.clone %330, %none_5883 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5556 = torch.aten.detach %5555 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5557 = torch.aten.detach %5556 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5558 = torch.aten.detach %5557 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_5884 = torch.constant.none + %5559 = torch.aten.clone %331, %none_5884 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5560 = torch.aten.detach %5559 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5561 = torch.aten.detach %5560 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5562 = torch.aten.detach %5561 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_5885 = torch.constant.none + %5563 = torch.aten.clone %332, %none_5885 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5564 = torch.aten.detach %5563 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5565 = torch.aten.detach %5564 : 
!torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5566 = torch.aten.detach %5565 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_5886 = torch.constant.int 32 + %int2_5887 = torch.constant.int 2 + %int8_5888 = torch.constant.int 8 + %int32_5889 = torch.constant.int 32 + %int128_5890 = torch.constant.int 128 + %5567 = torch.prim.ListConstruct %456, %int32_5886, %int2_5887, %int8_5888, %int32_5889, %int128_5890 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5568 = torch.aten.view %5554, %5567 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5568, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %5569 = torch_c.to_builtin_tensor %5568 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %5570 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_5891 = tensor.cast %5570 : tensor<4x?xi64> to tensor + %5571 = torch_c.to_builtin_tensor %5558 : !torch.vtensor<[],si64> -> tensor + %5572 = torch_c.to_builtin_tensor %5562 : !torch.vtensor<[],si64> -> tensor + %5573 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%5569, %cast_5891, %5571, %5572) : (tensor, tensor, tensor, tensor) -> tensor + %cast_5892 = tensor.cast %5573 : tensor to tensor<4x?x8x32x128xf16> + %5574 = torch_c.from_builtin_tensor %cast_5892 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %5574, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %5575 = torch_c.to_builtin_tensor %5568 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %5576 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_5893 = tensor.cast %5576 : tensor<4x?xi64> to tensor + %5577 = torch_c.to_builtin_tensor %5558 : !torch.vtensor<[],si64> -> tensor + %5578 = torch_c.to_builtin_tensor %5566 : !torch.vtensor<[],si64> -> tensor + %5579 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%5575, %cast_5893, %5577, %5578) : (tensor, tensor, tensor, tensor) -> tensor + %cast_5894 = tensor.cast %5579 : tensor to tensor<4x?x8x32x128xf16> + %5580 = torch_c.from_builtin_tensor %cast_5894 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %5580, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_5895 = torch.constant.int 2 + %int3_5896 = torch.constant.int 3 + %5581 = torch.aten.transpose.int %5574, %int2_5895, %int3_5896 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5581, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int0_5897 = torch.constant.int 0 - %cpu_5898 = torch.constant.device "cpu" - %false_5899 = torch.constant.bool false - %4767 = torch.aten.empty_strided %4765, %4766, %int4_5896, %int0_5897, %cpu_5898, %false_5899 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int21_5900 = torch.constant.int 21 - %4768 = 
torch.aten.fill.Scalar %4767, %int21_5900 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_5901 = torch.constant.int 4 - %int1_5902 = torch.constant.int 1 - %4769 = torch.prim.ListConstruct %int4_5901, %int1_5902 : (!torch.int, !torch.int) -> !torch.list - %4770 = torch.aten.repeat %4764, %4769 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_5903 = torch.constant.int 32 - %4771 = torch.aten.mul.Scalar %4760, %int32_5903 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5904 = torch.constant.int 1 - %4772 = torch.aten.add.Tensor %4771, %4768, %int1_5904 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_5905 = torch.constant.int 2 - %4773 = torch.aten.mul.Scalar %4772, %int2_5905 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5906 = torch.constant.int 1 - %4774 = torch.aten.add.Tensor %4773, %4770, %int1_5906 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_5907 = torch.constant.int 32 - %4775 = torch.aten.mul.Scalar %4774, %int32_5907 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_5908 = torch.constant.int 1 - %4776 = torch.aten.add.Tensor %4775, %4762, %int1_5908 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %4777 = torch.prim.ListConstruct %4776 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_5909 = torch.constant.bool false - %4778 = torch.aten.index_put %4757, %4777, %4709, %false_5909 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4778, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_5910 = torch.constant.int 32 - %int2_5911 = torch.constant.int 2 - %int32_5912 = torch.constant.int 32 - %int8_5913 = torch.constant.int 8 - %int128_5914 = torch.constant.int 128 - %4779 = torch.prim.ListConstruct %437, %int32_5910, %int2_5911, %int32_5912, %int8_5913, %int128_5914 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4780 = torch.aten.view %4778, %4779 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4780, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_5915 = torch.constant.int 2097152 - %4781 = torch.prim.ListConstruct %437, %int2097152_5915 : (!torch.int, !torch.int) -> !torch.list - %4782 = torch.aten.view %4780, %4781 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4782, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_5916 = torch.constant.int 4 - %4783 = torch.prim.ListConstruct %int4_5916, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_5917 = torch.constant.int 1 - %4784 = torch.prim.ListConstruct %358, %int1_5917 : (!torch.int, !torch.int) -> !torch.list + %5582 = torch.aten.clone %5581, %int0_5897 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5582, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_5898 = torch.constant.int 4 + %int8_5899 = torch.constant.int 8 + %int128_5900 = torch.constant.int 128 + %5583 = 
torch.prim.ListConstruct %int4_5898, %457, %int8_5899, %int128_5900 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5584 = torch.aten._unsafe_view %5582, %5583 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5584, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_5901 = torch.constant.int 2 + %int3_5902 = torch.constant.int 3 + %5585 = torch.aten.transpose.int %5580, %int2_5901, %int3_5902 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5585, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_5903 = torch.constant.int 0 + %5586 = torch.aten.clone %5585, %int0_5903 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5586, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_5904 = torch.constant.int 4 + %int8_5905 = torch.constant.int 8 + %int128_5906 = torch.constant.int 128 + %5587 = torch.prim.ListConstruct %int4_5904, %457, %int8_5905, %int128_5906 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5588 = torch.aten._unsafe_view %5586, %5587 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5588, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_5907 = torch.constant.int -2 + %5589 = torch.aten.unsqueeze %5584, %int-2_5907 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5589, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_5908 = torch.constant.int 4 + %int8_5909 = torch.constant.int 8 + %int4_5910 = torch.constant.int 4 + %int128_5911 = torch.constant.int 128 + %5590 = torch.prim.ListConstruct %int4_5908, %457, %int8_5909, %int4_5910, %int128_5911 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_5912 = torch.constant.bool false + %5591 = torch.aten.expand %5589, %5590, %false_5912 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5591, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_5913 = torch.constant.int 0 + %5592 = torch.aten.clone %5591, %int0_5913 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5592, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_5914 = torch.constant.int 4 + %int32_5915 = torch.constant.int 32 + %int128_5916 = torch.constant.int 128 + %5593 = torch.prim.ListConstruct %int4_5914, %457, %int32_5915, %int128_5916 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5594 = torch.aten._unsafe_view %5592, %5593 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5594, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_5917 = torch.constant.int -2 + %5595 = torch.aten.unsqueeze %5588, %int-2_5917 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5595, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : 
!torch.vtensor<[4,?,8,1,128],f16> %int4_5918 = torch.constant.int 4 - %int0_5919 = torch.constant.int 0 - %cpu_5920 = torch.constant.device "cpu" - %false_5921 = torch.constant.bool false - %4785 = torch.aten.empty_strided %4783, %4784, %int4_5918, %int0_5919, %cpu_5920, %false_5921 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4785, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int21_5922 = torch.constant.int 21 - %4786 = torch.aten.fill.Scalar %4785, %int21_5922 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4786, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_5923 = torch.constant.int 32 - %4787 = torch.aten.mul.Scalar %arg3, %int32_5923 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4787, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_5924 = torch.constant.int 1 - %4788 = torch.aten.add.Tensor %4787, %4786, %int1_5924 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4788, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_5925 = torch.constant.int 4 - %4789 = torch.aten.mul.int %int4_5925, %358 : !torch.int, !torch.int -> !torch.int - %4790 = torch.prim.ListConstruct %4789 : (!torch.int) -> !torch.list - %4791 = torch.aten.view %4788, %4790 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %4791, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_5926 = torch.constant.int 32 - %int2_5927 = torch.constant.int 2 - %int32_5928 = torch.constant.int 32 - %int8_5929 = torch.constant.int 8 - %int128_5930 = torch.constant.int 128 - %4792 = torch.prim.ListConstruct %437, %int32_5926, %int2_5927, %int32_5928, %int8_5929, %int128_5930 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4793 = torch.aten.view %4782, %4792 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4793, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_5931 = torch.constant.int 32 - %4794 = torch.aten.mul.int %437, %int32_5931 : !torch.int, !torch.int -> !torch.int + %int8_5919 = torch.constant.int 8 + %int4_5920 = torch.constant.int 4 + %int128_5921 = torch.constant.int 128 + %5596 = torch.prim.ListConstruct %int4_5918, %457, %int8_5919, %int4_5920, %int128_5921 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_5922 = torch.constant.bool false + %5597 = torch.aten.expand %5595, %5596, %false_5922 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5597, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_5923 = torch.constant.int 0 + %5598 = torch.aten.clone %5597, %int0_5923 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5598, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_5924 = torch.constant.int 4 + %int32_5925 = torch.constant.int 32 + %int128_5926 = torch.constant.int 128 + %5599 = torch.prim.ListConstruct %int4_5924, %457, %int32_5925, 
+ %5599 = torch.prim.ListConstruct %int4_5924, %457, %int32_5925, %int128_5926 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5600 = torch.aten._unsafe_view %5598, %5599 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %5600, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int1_5927 = torch.constant.int 1
+ %int2_5928 = torch.constant.int 2
+ %5601 = torch.aten.transpose.int %5483, %int1_5927, %int2_5928 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+ %int1_5929 = torch.constant.int 1
+ %int2_5930 = torch.constant.int 2
+ %5602 = torch.aten.transpose.int %5594, %int1_5929, %int2_5930 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %5602, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %int1_5931 = torch.constant.int 1
%int2_5932 = torch.constant.int 2
- %int32_5933 = torch.constant.int 32
- %int8_5934 = torch.constant.int 8
- %int128_5935 = torch.constant.int 128
- %4795 = torch.prim.ListConstruct %4794, %int2_5932, %int32_5933, %int8_5934, %int128_5935 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4796 = torch.aten.view %4793, %4795 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2,32,8,128],f16>
- torch.bind_symbolic_shape %4796, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16>
- %int0_5936 = torch.constant.int 0
- %4797 = torch.aten.index_select %4796, %int0_5936, %4791 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16>
- torch.bind_symbolic_shape %4797, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16>
- %int4_5937 = torch.constant.int 4
- %int2_5938 = torch.constant.int 2
- %int32_5939 = torch.constant.int 32
- %int8_5940 = torch.constant.int 8
- %int128_5941 = torch.constant.int 128
- %4798 = torch.prim.ListConstruct %int4_5937, %358, %int2_5938, %int32_5939, %int8_5940, %int128_5941 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4799 = torch.aten.view %4797, %4798 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %4799, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int0_5942 = torch.constant.int 0
- %int0_5943 = torch.constant.int 0
- %int9223372036854775807_5944 = torch.constant.int 9223372036854775807
- %int1_5945 = torch.constant.int 1
- %4800 = torch.aten.slice.Tensor %4799, %int0_5942, %int0_5943, %int9223372036854775807_5944, %int1_5945 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %4800, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int1_5946 = torch.constant.int 1
- %int0_5947 = torch.constant.int 0
- %int9223372036854775807_5948 = torch.constant.int 9223372036854775807
+ %5603 = torch.aten.transpose.int %5600, %int1_5931, %int2_5932 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %5603, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %float0.000000e00_5933 = torch.constant.float 0.000000e+00
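// NOTE (editor): decode-step attention follows. The query %5601 is [4,32,1,128] (one new
// token per sequence), the keys/values %5602/%5603 are the gathered cache of length s0*32,
// %470 looks to be the [4,1,1,?] additive attention mask, and the trailing none argument
// leaves the scale at its default (1/sqrt(128), assuming the usual SDPA convention).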
+ %false_5934 = torch.constant.bool false
+ %none_5935 = torch.constant.none
+ %5604:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5601, %5602, %5603, %float0.000000e00_5933, %false_5934, %470, %none_5935) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
+ %int1_5936 = torch.constant.int 1
+ %int2_5937 = torch.constant.int 2
+ %5605 = torch.aten.transpose.int %5604#0, %int1_5936, %int2_5937 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
+ %int4_5938 = torch.constant.int 4
+ %int1_5939 = torch.constant.int 1
+ %int4096_5940 = torch.constant.int 4096
+ %5606 = torch.prim.ListConstruct %int4_5938, %int1_5939, %int4096_5940 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5607 = torch.aten.view %5605, %5606 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+ %int-2_5941 = torch.constant.int -2
+ %int-1_5942 = torch.constant.int -1
+ %5608 = torch.aten.transpose.int %333, %int-2_5941, %int-1_5942 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int5_5943 = torch.constant.int 5
+ %5609 = torch.prims.convert_element_type %5608, %int5_5943 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int4_5944 = torch.constant.int 4
+ %int4096_5945 = torch.constant.int 4096
+ %5610 = torch.prim.ListConstruct %int4_5944, %int4096_5945 : (!torch.int, !torch.int) -> !torch.list<int>
+ %5611 = torch.aten.view %5607, %5610 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+ %5612 = torch.aten.mm %5611, %5609 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
+ %int4_5946 = torch.constant.int 4
+ %int1_5947 = torch.constant.int 1
+ %int4096_5948 = torch.constant.int 4096
+ %5613 = torch.prim.ListConstruct %int4_5946, %int1_5947, %int4096_5948 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5614 = torch.aten.view %5612, %5613 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
%int1_5949 = torch.constant.int 1
- %4801 = torch.aten.slice.Tensor %4800, %int1_5946, %int0_5947, %int9223372036854775807_5948, %int1_5949 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %4801, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int2_5950 = torch.constant.int 2
- %int0_5951 = torch.constant.int 0
- %4802 = torch.aten.select.int %4801, %int2_5950, %int0_5951 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %4802, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int32_5952 = torch.constant.int 32
- %4803 = torch.aten.mul.int %358, %int32_5952 : !torch.int, !torch.int -> !torch.int
- %int2_5953 = torch.constant.int 2
- %int0_5954 = torch.constant.int 0
- %int1_5955 = torch.constant.int 1
- %4804 = torch.aten.slice.Tensor %4802, %int2_5953, %int0_5954, %4803, %int1_5955 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %4804, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> :
!torch.vtensor<[4,?,32,8,128],f16> - %int0_5956 = torch.constant.int 0 - %4805 = torch.aten.clone %4804, %int0_5956 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4805, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_5957 = torch.constant.int 1 - %4806 = torch.aten.size.int %4801, %int1_5957 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_5958 = torch.constant.int 32 - %4807 = torch.aten.mul.int %4806, %int32_5958 : !torch.int, !torch.int -> !torch.int - %int4_5959 = torch.constant.int 4 - %int8_5960 = torch.constant.int 8 - %int128_5961 = torch.constant.int 128 - %4808 = torch.prim.ListConstruct %int4_5959, %4807, %int8_5960, %int128_5961 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4809 = torch.aten._unsafe_view %4805, %4808 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4809, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_5962 = torch.constant.int 0 - %int0_5963 = torch.constant.int 0 - %int9223372036854775807_5964 = torch.constant.int 9223372036854775807 + %5615 = torch.aten.add.Tensor %5436, %5614, %int1_5949 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_5950 = torch.constant.int 6 + %5616 = torch.prims.convert_element_type %5615, %int6_5950 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_5951 = torch.constant.int 2 + %5617 = torch.aten.pow.Tensor_Scalar %5616, %int2_5951 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_5952 = torch.constant.int -1 + %5618 = torch.prim.ListConstruct %int-1_5952 : (!torch.int) -> !torch.list + %true_5953 = torch.constant.bool true + %none_5954 = torch.constant.none + %5619 = torch.aten.mean.dim %5617, %5618, %true_5953, %none_5954 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_5955 = torch.constant.float 9.9999997473787516E-6 + %int1_5956 = torch.constant.int 1 + %5620 = torch.aten.add.Scalar %5619, %float9.999990e-06_5955, %int1_5956 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %5621 = torch.aten.rsqrt %5620 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %5622 = torch.aten.mul.Tensor %5616, %5621 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_5957 = torch.constant.int 5 + %5623 = torch.prims.convert_element_type %5622, %int5_5957 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %5624 = torch.aten.mul.Tensor %334, %5623 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_5958 = torch.constant.int 5 + %5625 = torch.prims.convert_element_type %5624, %int5_5958 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_5959 = torch.constant.int -2 + %int-1_5960 = torch.constant.int -1 + %5626 = torch.aten.transpose.int %335, %int-2_5959, %int-1_5960 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5961 = torch.constant.int 5 + %5627 = torch.prims.convert_element_type %5626, %int5_5961 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_5962 = 
torch.constant.int 4 + %int4096_5963 = torch.constant.int 4096 + %5628 = torch.prim.ListConstruct %int4_5962, %int4096_5963 : (!torch.int, !torch.int) -> !torch.list + %5629 = torch.aten.view %5625, %5628 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5630 = torch.aten.mm %5629, %5627 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_5964 = torch.constant.int 4 %int1_5965 = torch.constant.int 1 - %4810 = torch.aten.slice.Tensor %4809, %int0_5962, %int0_5963, %int9223372036854775807_5964, %int1_5965 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4810, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_5966 = torch.constant.int 0 - %int0_5967 = torch.constant.int 0 - %int9223372036854775807_5968 = torch.constant.int 9223372036854775807 - %int1_5969 = torch.constant.int 1 - %4811 = torch.aten.slice.Tensor %4799, %int0_5966, %int0_5967, %int9223372036854775807_5968, %int1_5969 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4811, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_5970 = torch.constant.int 1 - %int0_5971 = torch.constant.int 0 - %int9223372036854775807_5972 = torch.constant.int 9223372036854775807 + %int14336_5966 = torch.constant.int 14336 + %5631 = torch.prim.ListConstruct %int4_5964, %int1_5965, %int14336_5966 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5632 = torch.aten.view %5630, %5631 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %5633 = torch.aten.silu %5632 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_5967 = torch.constant.int -2 + %int-1_5968 = torch.constant.int -1 + %5634 = torch.aten.transpose.int %336, %int-2_5967, %int-1_5968 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_5969 = torch.constant.int 5 + %5635 = torch.prims.convert_element_type %5634, %int5_5969 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_5970 = torch.constant.int 4 + %int4096_5971 = torch.constant.int 4096 + %5636 = torch.prim.ListConstruct %int4_5970, %int4096_5971 : (!torch.int, !torch.int) -> !torch.list + %5637 = torch.aten.view %5625, %5636 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5638 = torch.aten.mm %5637, %5635 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_5972 = torch.constant.int 4 %int1_5973 = torch.constant.int 1 - %4812 = torch.aten.slice.Tensor %4811, %int1_5970, %int0_5971, %int9223372036854775807_5972, %int1_5973 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %4812, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_5974 = torch.constant.int 2 - %int1_5975 = torch.constant.int 1 - %4813 = torch.aten.select.int %4812, %int2_5974, %int1_5975 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4813, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - 
%int2_5976 = torch.constant.int 2 - %int0_5977 = torch.constant.int 0 - %int1_5978 = torch.constant.int 1 - %4814 = torch.aten.slice.Tensor %4813, %int2_5976, %int0_5977, %4803, %int1_5978 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4814, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_5979 = torch.constant.int 0 - %4815 = torch.aten.clone %4814, %int0_5979 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %4815, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_5980 = torch.constant.int 1 - %4816 = torch.aten.size.int %4812, %int1_5980 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_5981 = torch.constant.int 32 - %4817 = torch.aten.mul.int %4816, %int32_5981 : !torch.int, !torch.int -> !torch.int - %int4_5982 = torch.constant.int 4 - %int8_5983 = torch.constant.int 8 - %int128_5984 = torch.constant.int 128 - %4818 = torch.prim.ListConstruct %int4_5982, %4817, %int8_5983, %int128_5984 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4819 = torch.aten._unsafe_view %4815, %4818 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4819, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_5985 = torch.constant.int 0 - %int0_5986 = torch.constant.int 0 - %int9223372036854775807_5987 = torch.constant.int 9223372036854775807 - %int1_5988 = torch.constant.int 1 - %4820 = torch.aten.slice.Tensor %4819, %int0_5985, %int0_5986, %int9223372036854775807_5987, %int1_5988 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %4820, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_5989 = torch.constant.int -2 - %4821 = torch.aten.unsqueeze %4810, %int-2_5989 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4821, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int14336_5974 = torch.constant.int 14336 + %5639 = torch.prim.ListConstruct %int4_5972, %int1_5973, %int14336_5974 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5640 = torch.aten.view %5638, %5639 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %5641 = torch.aten.mul.Tensor %5633, %5640 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_5975 = torch.constant.int -2 + %int-1_5976 = torch.constant.int -1 + %5642 = torch.aten.transpose.int %337, %int-2_5975, %int-1_5976 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_5977 = torch.constant.int 5 + %5643 = torch.prims.convert_element_type %5642, %int5_5977 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_5978 = torch.constant.int 4 + %int14336_5979 = torch.constant.int 14336 + %5644 = torch.prim.ListConstruct %int4_5978, %int14336_5979 : (!torch.int, !torch.int) -> !torch.list + %5645 = torch.aten.view %5641, %5644 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %5646 = torch.aten.mm %5645, %5643 : !torch.vtensor<[4,14336],f16>, 
!torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_5980 = torch.constant.int 4 + %int1_5981 = torch.constant.int 1 + %int4096_5982 = torch.constant.int 4096 + %5647 = torch.prim.ListConstruct %int4_5980, %int1_5981, %int4096_5982 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5648 = torch.aten.view %5646, %5647 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_5983 = torch.constant.int 1 + %5649 = torch.aten.add.Tensor %5615, %5648, %int1_5983 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_5984 = torch.constant.int 6 + %5650 = torch.prims.convert_element_type %5649, %int6_5984 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_5985 = torch.constant.int 2 + %5651 = torch.aten.pow.Tensor_Scalar %5650, %int2_5985 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_5986 = torch.constant.int -1 + %5652 = torch.prim.ListConstruct %int-1_5986 : (!torch.int) -> !torch.list + %true_5987 = torch.constant.bool true + %none_5988 = torch.constant.none + %5653 = torch.aten.mean.dim %5651, %5652, %true_5987, %none_5988 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_5989 = torch.constant.float 9.9999997473787516E-6 %int1_5990 = torch.constant.int 1 - %4822 = torch.aten.size.int %4809, %int1_5990 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_5991 = torch.constant.int 4 - %int8_5992 = torch.constant.int 8 - %int4_5993 = torch.constant.int 4 - %int128_5994 = torch.constant.int 128 - %4823 = torch.prim.ListConstruct %int4_5991, %4822, %int8_5992, %int4_5993, %int128_5994 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_5995 = torch.constant.bool false - %4824 = torch.aten.expand %4821, %4823, %false_5995 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4824, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_5996 = torch.constant.int 0 - %4825 = torch.aten.clone %4824, %int0_5996 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4825, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_5997 = torch.constant.int 4 - %int32_5998 = torch.constant.int 32 - %int128_5999 = torch.constant.int 128 - %4826 = torch.prim.ListConstruct %int4_5997, %4822, %int32_5998, %int128_5999 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4827 = torch.aten._unsafe_view %4825, %4826 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4827, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_6000 = torch.constant.int -2 - %4828 = torch.aten.unsqueeze %4820, %int-2_6000 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %4828, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_6001 = torch.constant.int 1 - %4829 = torch.aten.size.int %4819, %int1_6001 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_6002 = torch.constant.int 4 - %int8_6003 = torch.constant.int 8 + %5654 = torch.aten.add.Scalar %5653, 
%float9.999990e-06_5989, %int1_5990 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %5655 = torch.aten.rsqrt %5654 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %5656 = torch.aten.mul.Tensor %5650, %5655 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_5991 = torch.constant.int 5 + %5657 = torch.prims.convert_element_type %5656, %int5_5991 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %5658 = torch.aten.mul.Tensor %338, %5657 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_5992 = torch.constant.int 5 + %5659 = torch.prims.convert_element_type %5658, %int5_5992 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_5993 = torch.constant.int -2 + %int-1_5994 = torch.constant.int -1 + %5660 = torch.aten.transpose.int %339, %int-2_5993, %int-1_5994 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_5995 = torch.constant.int 5 + %5661 = torch.prims.convert_element_type %5660, %int5_5995 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_5996 = torch.constant.int 4 + %int4096_5997 = torch.constant.int 4096 + %5662 = torch.prim.ListConstruct %int4_5996, %int4096_5997 : (!torch.int, !torch.int) -> !torch.list + %5663 = torch.aten.view %5659, %5662 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5664 = torch.aten.mm %5663, %5661 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_5998 = torch.constant.int 4 + %int1_5999 = torch.constant.int 1 + %int4096_6000 = torch.constant.int 4096 + %5665 = torch.prim.ListConstruct %int4_5998, %int1_5999, %int4096_6000 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5666 = torch.aten.view %5664, %5665 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_6001 = torch.constant.int -2 + %int-1_6002 = torch.constant.int -1 + %5667 = torch.aten.transpose.int %340, %int-2_6001, %int-1_6002 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6003 = torch.constant.int 5 + %5668 = torch.prims.convert_element_type %5667, %int5_6003 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> %int4_6004 = torch.constant.int 4 - %int128_6005 = torch.constant.int 128 - %4830 = torch.prim.ListConstruct %int4_6002, %4829, %int8_6003, %int4_6004, %int128_6005 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_6006 = torch.constant.bool false - %4831 = torch.aten.expand %4828, %4830, %false_6006 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4831, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_6007 = torch.constant.int 0 - %4832 = torch.aten.clone %4831, %int0_6007 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %4832, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_6008 = torch.constant.int 4 - %int32_6009 = torch.constant.int 32 - %int128_6010 = torch.constant.int 128 - %4833 = torch.prim.ListConstruct %int4_6008, %4829, %int32_6009, %int128_6010 : (!torch.int, !torch.int, !torch.int, 
!torch.int) -> !torch.list - %4834 = torch.aten._unsafe_view %4832, %4833 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %4834, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_6011 = torch.constant.int 1 - %int2_6012 = torch.constant.int 2 - %4835 = torch.aten.transpose.int %4715, %int1_6011, %int2_6012 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_6013 = torch.constant.int 1 - %int2_6014 = torch.constant.int 2 - %4836 = torch.aten.transpose.int %4827, %int1_6013, %int2_6014 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4836, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int4096_6005 = torch.constant.int 4096 + %5669 = torch.prim.ListConstruct %int4_6004, %int4096_6005 : (!torch.int, !torch.int) -> !torch.list + %5670 = torch.aten.view %5659, %5669 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5671 = torch.aten.mm %5670, %5668 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_6006 = torch.constant.int 4 + %int1_6007 = torch.constant.int 1 + %int1024_6008 = torch.constant.int 1024 + %5672 = torch.prim.ListConstruct %int4_6006, %int1_6007, %int1024_6008 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5673 = torch.aten.view %5671, %5672 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_6009 = torch.constant.int -2 + %int-1_6010 = torch.constant.int -1 + %5674 = torch.aten.transpose.int %341, %int-2_6009, %int-1_6010 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6011 = torch.constant.int 5 + %5675 = torch.prims.convert_element_type %5674, %int5_6011 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_6012 = torch.constant.int 4 + %int4096_6013 = torch.constant.int 4096 + %5676 = torch.prim.ListConstruct %int4_6012, %int4096_6013 : (!torch.int, !torch.int) -> !torch.list + %5677 = torch.aten.view %5659, %5676 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5678 = torch.aten.mm %5677, %5675 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_6014 = torch.constant.int 4 %int1_6015 = torch.constant.int 1 - %int2_6016 = torch.constant.int 2 - %4837 = torch.aten.transpose.int %4834, %int1_6015, %int2_6016 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %4837, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_6017 = torch.constant.float 0.000000e+00 - %false_6018 = torch.constant.bool false - %none_6019 = torch.constant.none - %4838:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%4835, %4836, %4837, %float0.000000e00_6017, %false_6018, %368, %none_6019) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_6020 = torch.constant.int 1 - %int2_6021 = torch.constant.int 2 - %4839 = torch.aten.transpose.int %4838#0, %int1_6020, %int2_6021 : 
!torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_6022 = torch.constant.int 4 - %int1_6023 = torch.constant.int 1 - %int4096_6024 = torch.constant.int 4096 - %4840 = torch.prim.ListConstruct %int4_6022, %int1_6023, %int4096_6024 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4841 = torch.aten.view %4839, %4840 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_6025 = torch.constant.int -2 - %int-1_6026 = torch.constant.int -1 - %4842 = torch.aten.transpose.int %238, %int-2_6025, %int-1_6026 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6027 = torch.constant.int 4 - %int4096_6028 = torch.constant.int 4096 - %4843 = torch.prim.ListConstruct %int4_6027, %int4096_6028 : (!torch.int, !torch.int) -> !torch.list - %4844 = torch.aten.view %4841, %4843 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4845 = torch.aten.mm %4844, %4842 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_6029 = torch.constant.int 4 - %int1_6030 = torch.constant.int 1 - %int4096_6031 = torch.constant.int 4096 - %4846 = torch.prim.ListConstruct %int4_6029, %int1_6030, %int4096_6031 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4847 = torch.aten.view %4845, %4846 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_6032 = torch.constant.int 1 - %4848 = torch.aten.add.Tensor %4675, %4847, %int1_6032 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_6033 = torch.constant.int 6 - %4849 = torch.prims.convert_element_type %4848, %int6_6033 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_6034 = torch.constant.int 2 - %4850 = torch.aten.pow.Tensor_Scalar %4849, %int2_6034 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_6035 = torch.constant.int -1 - %4851 = torch.prim.ListConstruct %int-1_6035 : (!torch.int) -> !torch.list - %true_6036 = torch.constant.bool true - %none_6037 = torch.constant.none - %4852 = torch.aten.mean.dim %4850, %4851, %true_6036, %none_6037 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_6038 = torch.constant.float 9.9999997473787516E-6 - %int1_6039 = torch.constant.int 1 - %4853 = torch.aten.add.Scalar %4852, %float9.999990e-06_6038, %int1_6039 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %4854 = torch.aten.rsqrt %4853 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %4855 = torch.aten.mul.Tensor %4849, %4854 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_6040 = torch.constant.int 5 - %4856 = torch.prims.convert_element_type %4855, %int5_6040 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %4857 = torch.aten.mul.Tensor %239, %4856 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_6041 = torch.constant.int 5 - %4858 = torch.prims.convert_element_type %4857, %int5_6041 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_6042 = torch.constant.int -2 - %int-1_6043 = torch.constant.int -1 - %4859 = torch.aten.transpose.int %240, %int-2_6042, %int-1_6043 : 
!torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_6044 = torch.constant.int 4 - %int4096_6045 = torch.constant.int 4096 - %4860 = torch.prim.ListConstruct %int4_6044, %int4096_6045 : (!torch.int, !torch.int) -> !torch.list - %4861 = torch.aten.view %4858, %4860 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4862 = torch.aten.mm %4861, %4859 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_6046 = torch.constant.int 4 - %int1_6047 = torch.constant.int 1 - %int14336_6048 = torch.constant.int 14336 - %4863 = torch.prim.ListConstruct %int4_6046, %int1_6047, %int14336_6048 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4864 = torch.aten.view %4862, %4863 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %4865 = torch.aten.silu %4864 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_6049 = torch.constant.int -2 - %int-1_6050 = torch.constant.int -1 - %4866 = torch.aten.transpose.int %241, %int-2_6049, %int-1_6050 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_6051 = torch.constant.int 4 - %int4096_6052 = torch.constant.int 4096 - %4867 = torch.prim.ListConstruct %int4_6051, %int4096_6052 : (!torch.int, !torch.int) -> !torch.list - %4868 = torch.aten.view %4858, %4867 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4869 = torch.aten.mm %4868, %4866 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_6053 = torch.constant.int 4 + %int1024_6016 = torch.constant.int 1024 + %5679 = torch.prim.ListConstruct %int4_6014, %int1_6015, %int1024_6016 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5680 = torch.aten.view %5678, %5679 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_6017 = torch.constant.int 4 + %int1_6018 = torch.constant.int 1 + %int32_6019 = torch.constant.int 32 + %int128_6020 = torch.constant.int 128 + %5681 = torch.prim.ListConstruct %int4_6017, %int1_6018, %int32_6019, %int128_6020 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5682 = torch.aten.view %5666, %5681 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_6021 = torch.constant.int 4 + %int1_6022 = torch.constant.int 1 + %int8_6023 = torch.constant.int 8 + %int128_6024 = torch.constant.int 128 + %5683 = torch.prim.ListConstruct %int4_6021, %int1_6022, %int8_6023, %int128_6024 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5684 = torch.aten.view %5673, %5683 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_6025 = torch.constant.int 4 + %int1_6026 = torch.constant.int 1 + %int8_6027 = torch.constant.int 8 + %int128_6028 = torch.constant.int 128 + %5685 = torch.prim.ListConstruct %int4_6025, %int1_6026, %int8_6027, %int128_6028 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5686 = torch.aten.view %5680, %5685 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_6029 = torch.constant.int 1 + %int2_6030 = torch.constant.int 2 + %5687 = torch.aten.transpose.int %5682, %int1_6029, %int2_6030 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %5688 = torch.aten.mul.Tensor %5687, %527 : !torch.vtensor<[4,32,1,128],f16>, 
!torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_6031 = torch.constant.int 3 + %int0_6032 = torch.constant.int 0 + %int64_6033 = torch.constant.int 64 + %int1_6034 = torch.constant.int 1 + %5689 = torch.aten.slice.Tensor %5687, %int3_6031, %int0_6032, %int64_6033, %int1_6034 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_6035 = torch.constant.int 3 + %int64_6036 = torch.constant.int 64 + %int9223372036854775807_6037 = torch.constant.int 9223372036854775807 + %int1_6038 = torch.constant.int 1 + %5690 = torch.aten.slice.Tensor %5687, %int3_6035, %int64_6036, %int9223372036854775807_6037, %int1_6038 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %5691 = torch.aten.neg %5690 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %5692 = torch.prim.ListConstruct %5691, %5689 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_6039 = torch.constant.int -1 + %5693 = torch.aten.cat %5692, %int-1_6039 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %5694 = torch.aten.mul.Tensor %5693, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_6040 = torch.constant.int 1 + %5695 = torch.aten.add.Tensor %5688, %5694, %int1_6040 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_6041 = torch.constant.int 1 + %int2_6042 = torch.constant.int 2 + %5696 = torch.aten.transpose.int %5695, %int1_6041, %int2_6042 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_6043 = torch.constant.int 1 + %int2_6044 = torch.constant.int 2 + %5697 = torch.aten.transpose.int %5684, %int1_6043, %int2_6044 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %5698 = torch.aten.mul.Tensor %5697, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_6045 = torch.constant.int 3 + %int0_6046 = torch.constant.int 0 + %int64_6047 = torch.constant.int 64 + %int1_6048 = torch.constant.int 1 + %5699 = torch.aten.slice.Tensor %5697, %int3_6045, %int0_6046, %int64_6047, %int1_6048 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_6049 = torch.constant.int 3 + %int64_6050 = torch.constant.int 64 + %int9223372036854775807_6051 = torch.constant.int 9223372036854775807 + %int1_6052 = torch.constant.int 1 + %5700 = torch.aten.slice.Tensor %5697, %int3_6049, %int64_6050, %int9223372036854775807_6051, %int1_6052 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %5701 = torch.aten.neg %5700 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %5702 = torch.prim.ListConstruct %5701, %5699 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_6053 = torch.constant.int -1 + %5703 = torch.aten.cat %5702, %int-1_6053 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %5704 = torch.aten.mul.Tensor %5703, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> %int1_6054 = torch.constant.int 1 - %int14336_6055 = torch.constant.int 14336 - %4870 = 
torch.prim.ListConstruct %int4_6053, %int1_6054, %int14336_6055 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4871 = torch.aten.view %4869, %4870 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %4872 = torch.aten.mul.Tensor %4865, %4871 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_6056 = torch.constant.int -2 - %int-1_6057 = torch.constant.int -1 - %4873 = torch.aten.transpose.int %242, %int-2_6056, %int-1_6057 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_6058 = torch.constant.int 4 - %int14336_6059 = torch.constant.int 14336 - %4874 = torch.prim.ListConstruct %int4_6058, %int14336_6059 : (!torch.int, !torch.int) -> !torch.list - %4875 = torch.aten.view %4872, %4874 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %4876 = torch.aten.mm %4875, %4873 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_6060 = torch.constant.int 4 - %int1_6061 = torch.constant.int 1 - %int4096_6062 = torch.constant.int 4096 - %4877 = torch.prim.ListConstruct %int4_6060, %int1_6061, %int4096_6062 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4878 = torch.aten.view %4876, %4877 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %5705 = torch.aten.add.Tensor %5698, %5704, %int1_6054 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_6055 = torch.constant.int 1 + %int2_6056 = torch.constant.int 2 + %5706 = torch.aten.transpose.int %5705, %int1_6055, %int2_6056 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_6057 = torch.constant.int 32 + %5707 = torch.aten.floor_divide.Scalar %arg2, %int32_6057 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_6058 = torch.constant.int 1 + %5708 = torch.aten.unsqueeze %5707, %int1_6058 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_6059 = torch.constant.int 1 + %false_6060 = torch.constant.bool false + %5709 = torch.aten.gather %arg3, %int1_6059, %5708, %false_6060 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_6061 = torch.constant.int 4 + %int1_6062 = torch.constant.int 1 %int1_6063 = torch.constant.int 1 - %4879 = torch.aten.add.Tensor %4848, %4878, %int1_6063 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_6064 = torch.constant.int 6 - %4880 = torch.prims.convert_element_type %4879, %int6_6064 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_6065 = torch.constant.int 2 - %4881 = torch.aten.pow.Tensor_Scalar %4880, %int2_6065 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_6066 = torch.constant.int -1 - %4882 = torch.prim.ListConstruct %int-1_6066 : (!torch.int) -> !torch.list - %true_6067 = torch.constant.bool true - %none_6068 = torch.constant.none - %4883 = torch.aten.mean.dim %4881, %4882, %true_6067, %none_6068 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_6069 = torch.constant.float 9.9999997473787516E-6 - %int1_6070 = torch.constant.int 1 - %4884 = torch.aten.add.Scalar %4883, %float9.999990e-06_6069, 
%int1_6070 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %4885 = torch.aten.rsqrt %4884 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %4886 = torch.aten.mul.Tensor %4880, %4885 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_6071 = torch.constant.int 5 - %4887 = torch.prims.convert_element_type %4886, %int5_6071 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %4888 = torch.aten.mul.Tensor %243, %4887 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_6072 = torch.constant.int 5 - %4889 = torch.prims.convert_element_type %4888, %int5_6072 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_6073 = torch.constant.int -2 - %int-1_6074 = torch.constant.int -1 - %4890 = torch.aten.transpose.int %244, %int-2_6073, %int-1_6074 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6075 = torch.constant.int 4 - %int4096_6076 = torch.constant.int 4096 - %4891 = torch.prim.ListConstruct %int4_6075, %int4096_6076 : (!torch.int, !torch.int) -> !torch.list - %4892 = torch.aten.view %4889, %4891 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4893 = torch.aten.mm %4892, %4890 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_6077 = torch.constant.int 4 + %5710 = torch.prim.ListConstruct %int4_6061, %int1_6062, %int1_6063 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5711 = torch.aten.view %5709, %5710 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_6064 = torch.constant.int 32 + %5712 = torch.aten.remainder.Scalar %arg2, %int32_6064 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_6065 = torch.constant.int 4 + %int1_6066 = torch.constant.int 1 + %int1_6067 = torch.constant.int 1 + %5713 = torch.prim.ListConstruct %int4_6065, %int1_6066, %int1_6067 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5714 = torch.aten.view %5712, %5713 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_6068 = torch.constant.int 8 + %none_6069 = torch.constant.none + %none_6070 = torch.constant.none + %cpu_6071 = torch.constant.device "cpu" + %false_6072 = torch.constant.bool false + %5715 = torch.aten.arange %int8_6068, %none_6069, %none_6070, %cpu_6071, %false_6072 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_6073 = torch.constant.int 1 + %int1_6074 = torch.constant.int 1 + %int8_6075 = torch.constant.int 8 + %5716 = torch.prim.ListConstruct %int1_6073, %int1_6074, %int8_6075 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5717 = torch.aten.view %5715, %5716 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_6076 = torch.constant.none + %5718 = torch.aten.clone %342, %none_6076 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5719 = torch.aten.detach %5718 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5720 = torch.aten.detach %5719 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5721 = torch.aten.detach %5720 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_6077 = torch.constant.int 1 %int1_6078 = torch.constant.int 1 - %int4096_6079 = torch.constant.int 4096 - %4894 = torch.prim.ListConstruct %int4_6077, 
%int1_6078, %int4096_6079 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4895 = torch.aten.view %4893, %4894 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_6080 = torch.constant.int -2 - %int-1_6081 = torch.constant.int -1 - %4896 = torch.aten.transpose.int %245, %int-2_6080, %int-1_6081 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6082 = torch.constant.int 4 - %int4096_6083 = torch.constant.int 4096 - %4897 = torch.prim.ListConstruct %int4_6082, %int4096_6083 : (!torch.int, !torch.int) -> !torch.list - %4898 = torch.aten.view %4889, %4897 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4899 = torch.aten.mm %4898, %4896 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_6084 = torch.constant.int 4 + %int1_6079 = torch.constant.int 1 + %5722 = torch.prim.ListConstruct %int1_6077, %int1_6078, %int1_6079 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5723 = torch.aten.view %5721, %5722 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_6080 = torch.constant.int 32 + %5724 = torch.aten.mul.Scalar %5711, %int32_6080 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int24 = torch.constant.int 24 + %int1_6081 = torch.constant.int 1 + %5725 = torch.aten.add.Scalar %5724, %int24, %int1_6081 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_6082 = torch.constant.int 2 + %5726 = torch.aten.mul.Scalar %5725, %int2_6082 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_6083 = torch.constant.int 1 + %5727 = torch.aten.add.Tensor %5726, %5723, %int1_6083 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_6084 = torch.constant.int 8 + %5728 = torch.aten.mul.Scalar %5727, %int8_6084 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_6085 = torch.constant.int 1 - %int1024_6086 = torch.constant.int 1024 - %4900 = torch.prim.ListConstruct %int4_6084, %int1_6085, %int1024_6086 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4901 = torch.aten.view %4899, %4900 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_6087 = torch.constant.int -2 - %int-1_6088 = torch.constant.int -1 - %4902 = torch.aten.transpose.int %246, %int-2_6087, %int-1_6088 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6089 = torch.constant.int 4 - %int4096_6090 = torch.constant.int 4096 - %4903 = torch.prim.ListConstruct %int4_6089, %int4096_6090 : (!torch.int, !torch.int) -> !torch.list - %4904 = torch.aten.view %4889, %4903 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %4905 = torch.aten.mm %4904, %4902 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_6091 = torch.constant.int 4 - %int1_6092 = torch.constant.int 1 - %int1024_6093 = torch.constant.int 1024 - %4906 = torch.prim.ListConstruct %int4_6091, %int1_6092, %int1024_6093 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4907 = torch.aten.view %4905, %4906 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_6094 = torch.constant.int 4 - %int1_6095 = torch.constant.int 1 + %5729 = torch.aten.add.Tensor %5728, %5717, %int1_6085 : 
!torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int32_6086 = torch.constant.int 32
+ %5730 = torch.aten.mul.Scalar %5729, %int32_6086 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int1_6087 = torch.constant.int 1
+ %5731 = torch.aten.add.Tensor %5730, %5714, %int1_6087 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int5_6088 = torch.constant.int 5
+ %5732 = torch.prims.convert_element_type %5706, %int5_6088 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+ %int32_6089 = torch.constant.int 32
+ %int2_6090 = torch.constant.int 2
+ %int8_6091 = torch.constant.int 8
+ %int32_6092 = torch.constant.int 32
+ %int128_6093 = torch.constant.int 128
+ %5733 = torch.prim.ListConstruct %456, %int32_6089, %int2_6090, %int8_6091, %int32_6092, %int128_6093 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5734 = torch.aten.view %5554, %5733 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %5734, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int128_6094 = torch.constant.int 128
+ %5735 = torch.prim.ListConstruct %596, %int128_6094 : (!torch.int, !torch.int) -> !torch.list<int>
+ %5736 = torch.aten.view %5734, %5735 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %5736, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %5737 = torch.prim.ListConstruct %5731 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+ %false_6095 = torch.constant.bool false
+ %5738 = torch.aten.index_put %5736, %5737, %5732, %false_6095 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %5738, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
%int32_6096 = torch.constant.int 32
+ %int2_6097 = torch.constant.int 2
+ %int8_6098 = torch.constant.int 8
+ %int32_6099 = torch.constant.int 32
+ %int128_6100 = torch.constant.int 128
+ %5739 = torch.prim.ListConstruct %456, %int32_6096, %int2_6097, %int8_6098, %int32_6099, %int128_6100 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5740 = torch.aten.view %5738, %5739 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %5740, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
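// NOTE (editor): paged KV-cache write. The flat cache %5554 is viewed as
// [pages, 32, 2, 8, 32, 128] -- plausibly [page, transformer block, k/v plane, kv head,
// position-in-page, head dim] -- then flattened to rows of 128, and the index_put %5738
// scatters this step's K rows at the linearized slots built in %5724-%5731 above (the
// constant 24 there appears to be this block's index). The views below restore the
// [?,2097152] cache shape.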
+ %int2097152_6101 = torch.constant.int 2097152
+ %5741 = torch.prim.ListConstruct %456, %int2097152_6101 : (!torch.int, !torch.int) -> !torch.list<int>
+ %5742 = torch.aten.view %5740, %5741 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %5742, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %int32_6102 = torch.constant.int 32
+ %int2_6103 = torch.constant.int 2
%int8_6104 = torch.constant.int 8
- %int128_6105 = torch.constant.int 128
- %4912 = torch.prim.ListConstruct %int4_6102, %int1_6103, %int8_6104, %int128_6105 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %4913 = torch.aten.view %4907, %4912 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
- %int6_6106 = torch.constant.int 6
- %4914 = torch.prims.convert_element_type %4909, %int6_6106 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32>
- %4915 = torch_c.to_builtin_tensor %4914 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32>
- %4916 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32>
- %4917 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%4915, %4916) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32>
- %4918 = torch_c.from_builtin_tensor %4917 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32>
- %int5_6107 = torch.constant.int 5
- %4919 = torch.prims.convert_element_type %4918, %int5_6107 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
- %int6_6108 = torch.constant.int 6
- %4920 = torch.prims.convert_element_type %4911, %int6_6108 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32>
- %4921 = torch_c.to_builtin_tensor %4920 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32>
- %4922 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32>
- %4923 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%4921, %4922) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32>
- %4924 = torch_c.from_builtin_tensor %4923 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32>
- %int5_6109 = torch.constant.int 5
- %4925 = torch.prims.convert_element_type %4924, %int5_6109 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
- %int32_6110 = torch.constant.int 32
- %4926 = torch.aten.floor_divide.Scalar %arg2, %int32_6110 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
+ %int32_6105 = torch.constant.int 32
+ %int128_6106 = torch.constant.int 128
+ %5743 = torch.prim.ListConstruct %456, %int32_6102, %int2_6103, %int8_6104, %int32_6105, %int128_6106 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %5744 = torch.aten.view %5742, %5743 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %5744, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int128_6107 = torch.constant.int 128
+ %5745 = torch.prim.ListConstruct %596, %int128_6107 : (!torch.int, !torch.int) -> !torch.list<int>
+ %5746 = torch.aten.view %5744, %5745 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %5746, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %none_6108 = torch.constant.none
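// NOTE (editor): the slot arithmetic is now repeated with %343 (cloned and detached below)
// in place of %342 -- presumably the selector for the V plane of the cache -- before a
// second index_put writes this step's V rows.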
!torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5748 = torch.aten.detach %5747 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5749 = torch.aten.detach %5748 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5750 = torch.aten.detach %5749 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_6109 = torch.constant.int 1 + %int1_6110 = torch.constant.int 1 %int1_6111 = torch.constant.int 1 - %4927 = torch.aten.unsqueeze %4926, %int1_6111 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6112 = torch.constant.int 1 - %false_6113 = torch.constant.bool false - %4928 = torch.aten.gather %arg3, %int1_6112, %4927, %false_6113 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_6114 = torch.constant.int 32 - %4929 = torch.aten.remainder.Scalar %arg2, %int32_6114 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_6115 = torch.constant.int 1 - %4930 = torch.aten.unsqueeze %4929, %int1_6115 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_6116 = torch.constant.none - %4931 = torch.aten.clone %247, %none_6116 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_6117 = torch.constant.int 0 - %4932 = torch.aten.unsqueeze %4931, %int0_6117 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_6118 = torch.constant.int 4 - %int1_6119 = torch.constant.int 1 - %4933 = torch.prim.ListConstruct %int4_6118, %int1_6119 : (!torch.int, !torch.int) -> !torch.list + %5751 = torch.prim.ListConstruct %int1_6109, %int1_6110, %int1_6111 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5752 = torch.aten.view %5750, %5751 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_6112 = torch.constant.int 32 + %5753 = torch.aten.mul.Scalar %5711, %int32_6112 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int24_6113 = torch.constant.int 24 + %int1_6114 = torch.constant.int 1 + %5754 = torch.aten.add.Scalar %5753, %int24_6113, %int1_6114 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_6115 = torch.constant.int 2 + %5755 = torch.aten.mul.Scalar %5754, %int2_6115 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_6116 = torch.constant.int 1 + %5756 = torch.aten.add.Tensor %5755, %5752, %int1_6116 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_6117 = torch.constant.int 8 + %5757 = torch.aten.mul.Scalar %5756, %int8_6117 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_6118 = torch.constant.int 1 + %5758 = torch.aten.add.Tensor %5757, %5717, %int1_6118 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_6119 = torch.constant.int 32 + %5759 = torch.aten.mul.Scalar %5758, %int32_6119 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_6120 = torch.constant.int 1 - %int1_6121 = torch.constant.int 1 - %4934 = torch.prim.ListConstruct %int1_6120, %int1_6121 : (!torch.int, !torch.int) -> !torch.list - %int4_6122 = torch.constant.int 4 - %int0_6123 = torch.constant.int 0 - %cpu_6124 = torch.constant.device "cpu" - %false_6125 = torch.constant.bool false - %4935 = torch.aten.empty_strided %4933, %4934, %int4_6122, %int0_6123, %cpu_6124, %false_6125 : !torch.list, !torch.list, !torch.int, 
!torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int22 = torch.constant.int 22 - %4936 = torch.aten.fill.Scalar %4935, %int22 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_6126 = torch.constant.int 4 - %int1_6127 = torch.constant.int 1 - %4937 = torch.prim.ListConstruct %int4_6126, %int1_6127 : (!torch.int, !torch.int) -> !torch.list<int> - %4938 = torch.aten.repeat %4932, %4937 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64> - %int32_6128 = torch.constant.int 32 - %4939 = torch.aten.mul.Scalar %4928, %int32_6128 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6129 = torch.constant.int 1 - %4940 = torch.aten.add.Tensor %4939, %4936, %int1_6129 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_6130 = torch.constant.int 2 - %4941 = torch.aten.mul.Scalar %4940, %int2_6130 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6131 = torch.constant.int 1 - %4942 = torch.aten.add.Tensor %4941, %4938, %int1_6131 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %5760 = torch.aten.add.Tensor %5759, %5714, %int1_6120 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_6121 = torch.constant.int 5 + %5761 = torch.prims.convert_element_type %5686, %int5_6121 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %5762 = torch.prim.ListConstruct %5760 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>> + %false_6122 = torch.constant.bool false + %5763 = torch.aten.index_put %5746, %5762, %5761, %false_6122 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5763, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_6123 = torch.constant.int 32 + %int2_6124 = torch.constant.int 2 + %int8_6125 = torch.constant.int 8 + %int32_6126 = torch.constant.int 32 + %int128_6127 = torch.constant.int 128 + %5764 = torch.prim.ListConstruct %456, %int32_6123, %int2_6124, %int8_6125, %int32_6126, %int128_6127 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %5765 = torch.aten.view %5763, %5764 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5765, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_6128 = torch.constant.int 2097152 + %5766 = torch.prim.ListConstruct %456, %int2097152_6128 : (!torch.int, !torch.int) -> !torch.list<int> + %5767 = torch.aten.view %5765, %5766 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5767, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_6129 = torch.constant.none + %5768 = torch.aten.clone %344, %none_6129 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5769 = torch.aten.detach %5768 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5770 = torch.aten.detach %5769 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5771 = torch.aten.detach %5770 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_6130 = torch.constant.none + %5772 = torch.aten.clone %345, %none_6130 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5773
= torch.aten.detach %5772 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5774 = torch.aten.detach %5773 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5775 = torch.aten.detach %5774 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_6131 = torch.constant.none + %5776 = torch.aten.clone %346, %none_6131 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5777 = torch.aten.detach %5776 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5778 = torch.aten.detach %5777 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5779 = torch.aten.detach %5778 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %int32_6132 = torch.constant.int 32 - %4943 = torch.aten.mul.Scalar %4942, %int32_6132 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6133 = torch.constant.int 1 - %4944 = torch.aten.add.Tensor %4943, %4930, %int1_6133 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_6134 = torch.constant.int 32 - %int2_6135 = torch.constant.int 2 - %int32_6136 = torch.constant.int 32 - %int8_6137 = torch.constant.int 8 - %int128_6138 = torch.constant.int 128 - %4945 = torch.prim.ListConstruct %437, %int32_6134, %int2_6135, %int32_6136, %int8_6137, %int128_6138 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %4946 = torch.aten.view %4782, %4945 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4946, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6139 = torch.constant.int 32 - %4947 = torch.aten.mul.int %437, %int32_6139 : !torch.int, !torch.int -> !torch.int - %int2_6140 = torch.constant.int 2 - %4948 = torch.aten.mul.int %4947, %int2_6140 : !torch.int, !torch.int -> !torch.int - %int32_6141 = torch.constant.int 32 - %4949 = torch.aten.mul.int %4948, %int32_6141 : !torch.int, !torch.int -> !torch.int - %int8_6142 = torch.constant.int 8 - %int128_6143 = torch.constant.int 128 - %4950 = torch.prim.ListConstruct %4949, %int8_6142, %int128_6143 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> - %4951 = torch.aten.view %4946, %4950 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4951, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %4952 = torch.prim.ListConstruct %4944 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>> - %false_6144 = torch.constant.bool false - %4953 = torch.aten.index_put %4951, %4952, %4925, %false_6144 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4953, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_6145 = torch.constant.int 32 - %int2_6146 = torch.constant.int 2 - %int32_6147 = torch.constant.int 32 - %int8_6148 = torch.constant.int 8 - %int128_6149 = torch.constant.int 128 - %4954 = torch.prim.ListConstruct %437, %int32_6145, %int2_6146, %int32_6147, %int8_6148, %int128_6149 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %4955 = torch.aten.view %4953, %4954 : !torch.vtensor<[?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4955, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6150 =
torch.constant.int 2097152 - %4956 = torch.prim.ListConstruct %437, %int2097152_6150 : (!torch.int, !torch.int) -> !torch.list<int> - %4957 = torch.aten.view %4955, %4956 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4957, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_6151 = torch.constant.int 32 - %int2_6152 = torch.constant.int 2 - %int32_6153 = torch.constant.int 32 - %int8_6154 = torch.constant.int 8 - %int128_6155 = torch.constant.int 128 - %4958 = torch.prim.ListConstruct %437, %int32_6151, %int2_6152, %int32_6153, %int8_6154, %int128_6155 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %4959 = torch.aten.view %4957, %4958 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4959, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_6156 = torch.constant.int 8 + %int2_6133 = torch.constant.int 2 + %int8_6134 = torch.constant.int 8 + %int32_6135 = torch.constant.int 32 + %int128_6136 = torch.constant.int 128 + %5780 = torch.prim.ListConstruct %456, %int32_6132, %int2_6133, %int8_6134, %int32_6135, %int128_6136 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %5781 = torch.aten.view %5767, %5780 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5781, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %5782 = torch_c.to_builtin_tensor %5781 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16> + %5783 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_6137 = tensor.cast %5783 : tensor<4x?xi64> to tensor<?x?xi64> + %5784 = torch_c.to_builtin_tensor %5771 : !torch.vtensor<[],si64> -> tensor<i64> + %5785 = torch_c.to_builtin_tensor %5775 : !torch.vtensor<[],si64> -> tensor<i64> + %5786 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%5782, %cast_6137, %5784, %5785) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16> + %cast_6138 = tensor.cast %5786 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16> + %5787 = torch_c.from_builtin_tensor %cast_6138 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %5787, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %5788 = torch_c.to_builtin_tensor %5781 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16> + %5789 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_6139 = tensor.cast %5789 : tensor<4x?xi64> to tensor<?x?xi64> + %5790 = torch_c.to_builtin_tensor %5771 : !torch.vtensor<[],si64> -> tensor<i64> + %5791 = torch_c.to_builtin_tensor %5779 : !torch.vtensor<[],si64> -> tensor<i64> + %5792 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%5788, %cast_6139, %5790, %5791) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16> + %cast_6140 = tensor.cast %5792 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16> + %5793 = torch_c.from_builtin_tensor %cast_6140 : tensor<4x?x8x32x128xf16> ->
!torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %5793, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_6141 = torch.constant.int 2 + %int3_6142 = torch.constant.int 3 + %5794 = torch.aten.transpose.int %5787, %int2_6141, %int3_6142 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5794, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_6143 = torch.constant.int 0 + %5795 = torch.aten.clone %5794, %int0_6143 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5795, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_6144 = torch.constant.int 4 + %int8_6145 = torch.constant.int 8 + %int128_6146 = torch.constant.int 128 + %5796 = torch.prim.ListConstruct %int4_6144, %457, %int8_6145, %int128_6146 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5797 = torch.aten._unsafe_view %5795, %5796 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5797, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_6147 = torch.constant.int 2 + %int3_6148 = torch.constant.int 3 + %5798 = torch.aten.transpose.int %5793, %int2_6147, %int3_6148 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5798, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_6149 = torch.constant.int 0 + %5799 = torch.aten.clone %5798, %int0_6149 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %5799, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_6150 = torch.constant.int 4 + %int8_6151 = torch.constant.int 8 + %int128_6152 = torch.constant.int 128 + %5800 = torch.prim.ListConstruct %int4_6150, %457, %int8_6151, %int128_6152 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5801 = torch.aten._unsafe_view %5799, %5800 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %5801, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_6153 = torch.constant.int -2 + %5802 = torch.aten.unsqueeze %5797, %int-2_6153 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5802, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_6154 = torch.constant.int 4 + %int8_6155 = torch.constant.int 8 + %int4_6156 = torch.constant.int 4 %int128_6157 = torch.constant.int 128 - %4960 = torch.prim.ListConstruct %4949, %int8_6156, %int128_6157 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %4961 = torch.aten.view %4959, %4960 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4961, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_6158 = torch.constant.int 32 - %4962 = torch.aten.floor_divide.Scalar %arg2, %int32_6158 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_6159 = torch.constant.int 1 - %4963 = torch.aten.unsqueeze %4962, %int1_6159 : !torch.vtensor<[4],si64>, 
!torch.int -> !torch.vtensor<[4,1],si64> - %int1_6160 = torch.constant.int 1 - %false_6161 = torch.constant.bool false - %4964 = torch.aten.gather %arg3, %int1_6160, %4963, %false_6161 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_6162 = torch.constant.int 32 - %4965 = torch.aten.remainder.Scalar %arg2, %int32_6162 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_6163 = torch.constant.int 1 - %4966 = torch.aten.unsqueeze %4965, %int1_6163 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_6164 = torch.constant.none - %4967 = torch.aten.clone %248, %none_6164 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_6165 = torch.constant.int 0 - %4968 = torch.aten.unsqueeze %4967, %int0_6165 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + %5803 = torch.prim.ListConstruct %int4_6154, %457, %int8_6155, %int4_6156, %int128_6157 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_6158 = torch.constant.bool false + %5804 = torch.aten.expand %5802, %5803, %false_6158 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5804, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_6159 = torch.constant.int 0 + %5805 = torch.aten.clone %5804, %int0_6159 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5805, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_6160 = torch.constant.int 4 + %int32_6161 = torch.constant.int 32 + %int128_6162 = torch.constant.int 128 + %5806 = torch.prim.ListConstruct %int4_6160, %457, %int32_6161, %int128_6162 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5807 = torch.aten._unsafe_view %5805, %5806 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5807, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_6163 = torch.constant.int -2 + %5808 = torch.aten.unsqueeze %5801, %int-2_6163 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %5808, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_6164 = torch.constant.int 4 + %int8_6165 = torch.constant.int 8 %int4_6166 = torch.constant.int 4 - %int1_6167 = torch.constant.int 1 - %4969 = torch.prim.ListConstruct %int4_6166, %int1_6167 : (!torch.int, !torch.int) -> !torch.list - %int1_6168 = torch.constant.int 1 - %int1_6169 = torch.constant.int 1 - %4970 = torch.prim.ListConstruct %int1_6168, %int1_6169 : (!torch.int, !torch.int) -> !torch.list + %int128_6167 = torch.constant.int 128 + %5809 = torch.prim.ListConstruct %int4_6164, %457, %int8_6165, %int4_6166, %int128_6167 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_6168 = torch.constant.bool false + %5810 = torch.aten.expand %5808, %5809, %false_6168 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5810, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_6169 = torch.constant.int 0 + %5811 = torch.aten.clone %5810, %int0_6169 : 
!torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %5811, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_6170 = torch.constant.int 4 - %int0_6171 = torch.constant.int 0 - %cpu_6172 = torch.constant.device "cpu" - %false_6173 = torch.constant.bool false - %4971 = torch.aten.empty_strided %4969, %4970, %int4_6170, %int0_6171, %cpu_6172, %false_6173 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int22_6174 = torch.constant.int 22 - %4972 = torch.aten.fill.Scalar %4971, %int22_6174 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_6175 = torch.constant.int 4 - %int1_6176 = torch.constant.int 1 - %4973 = torch.prim.ListConstruct %int4_6175, %int1_6176 : (!torch.int, !torch.int) -> !torch.list - %4974 = torch.aten.repeat %4968, %4973 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_6177 = torch.constant.int 32 - %4975 = torch.aten.mul.Scalar %4964, %int32_6177 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6178 = torch.constant.int 1 - %4976 = torch.aten.add.Tensor %4975, %4972, %int1_6178 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_6179 = torch.constant.int 2 - %4977 = torch.aten.mul.Scalar %4976, %int2_6179 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6180 = torch.constant.int 1 - %4978 = torch.aten.add.Tensor %4977, %4974, %int1_6180 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_6181 = torch.constant.int 32 - %4979 = torch.aten.mul.Scalar %4978, %int32_6181 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int32_6171 = torch.constant.int 32 + %int128_6172 = torch.constant.int 128 + %5812 = torch.prim.ListConstruct %int4_6170, %457, %int32_6171, %int128_6172 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5813 = torch.aten._unsafe_view %5811, %5812 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %5813, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_6173 = torch.constant.int 1 + %int2_6174 = torch.constant.int 2 + %5814 = torch.aten.transpose.int %5696, %int1_6173, %int2_6174 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_6175 = torch.constant.int 1 + %int2_6176 = torch.constant.int 2 + %5815 = torch.aten.transpose.int %5807, %int1_6175, %int2_6176 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5815, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_6177 = torch.constant.int 1 + %int2_6178 = torch.constant.int 2 + %5816 = torch.aten.transpose.int %5813, %int1_6177, %int2_6178 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %5816, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_6179 = torch.constant.float 0.000000e+00 + %false_6180 = torch.constant.bool false + %none_6181 = torch.constant.none + %5817:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5814, %5815, %5816, 
%float0.000000e00_6179, %false_6180, %470, %none_6181) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) %int1_6182 = torch.constant.int 1 - %4980 = torch.aten.add.Tensor %4979, %4966, %int1_6182 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %4981 = torch.prim.ListConstruct %4980 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>> - %false_6183 = torch.constant.bool false - %4982 = torch.aten.index_put %4961, %4981, %4913, %false_6183 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %4982, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_6184 = torch.constant.int 32 - %int2_6185 = torch.constant.int 2 - %int32_6186 = torch.constant.int 32 - %int8_6187 = torch.constant.int 8 - %int128_6188 = torch.constant.int 128 - %4983 = torch.prim.ListConstruct %437, %int32_6184, %int2_6185, %int32_6186, %int8_6187, %int128_6188 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %4984 = torch.aten.view %4982, %4983 : !torch.vtensor<[?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4984, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6189 = torch.constant.int 2097152 - %4985 = torch.prim.ListConstruct %437, %int2097152_6189 : (!torch.int, !torch.int) -> !torch.list<int> - %4986 = torch.aten.view %4984, %4985 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %4986, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int2_6183 = torch.constant.int 2 + %5818 = torch.aten.transpose.int %5817#0, %int1_6182, %int2_6183 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_6184 = torch.constant.int 4 + %int1_6185 = torch.constant.int 1 + %int4096_6186 = torch.constant.int 4096 + %5819 = torch.prim.ListConstruct %int4_6184, %int1_6185, %int4096_6186 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> + %5820 = torch.aten.view %5818, %5819 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16> + %int-2_6187 = torch.constant.int -2 + %int-1_6188 = torch.constant.int -1 + %5821 = torch.aten.transpose.int %347, %int-2_6187, %int-1_6188 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6189 = torch.constant.int 5 + %5822 = torch.prims.convert_element_type %5821, %int5_6189 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> %int4_6190 = torch.constant.int 4 - %4987 = torch.prim.ListConstruct %int4_6190, %358 : (!torch.int, !torch.int) -> !torch.list<int> - %int1_6191 = torch.constant.int 1 - %4988 = torch.prim.ListConstruct %358, %int1_6191 : (!torch.int, !torch.int) -> !torch.list<int> + %int4096_6191 = torch.constant.int 4096 + %5823 = torch.prim.ListConstruct %int4_6190, %int4096_6191 : (!torch.int, !torch.int) -> !torch.list<int> + %5824 = torch.aten.view %5820, %5823 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16> + %5825 = torch.aten.mm %5824, %5822 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> ->
!torch.vtensor<[4,4096],f16> %int4_6192 = torch.constant.int 4 - %int0_6193 = torch.constant.int 0 - %cpu_6194 = torch.constant.device "cpu" - %false_6195 = torch.constant.bool false - %4989 = torch.aten.empty_strided %4987, %4988, %int4_6192, %int0_6193, %cpu_6194, %false_6195 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4989, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int22_6196 = torch.constant.int 22 - %4990 = torch.aten.fill.Scalar %4989, %int22_6196 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4990, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_6197 = torch.constant.int 32 - %4991 = torch.aten.mul.Scalar %arg3, %int32_6197 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4991, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_6198 = torch.constant.int 1 - %4992 = torch.aten.add.Tensor %4991, %4990, %int1_6198 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %4992, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_6199 = torch.constant.int 4 - %4993 = torch.aten.mul.int %int4_6199, %358 : !torch.int, !torch.int -> !torch.int - %4994 = torch.prim.ListConstruct %4993 : (!torch.int) -> !torch.list - %4995 = torch.aten.view %4992, %4994 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %4995, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_6200 = torch.constant.int 32 - %int2_6201 = torch.constant.int 2 - %int32_6202 = torch.constant.int 32 - %int8_6203 = torch.constant.int 8 - %int128_6204 = torch.constant.int 128 - %4996 = torch.prim.ListConstruct %437, %int32_6200, %int2_6201, %int32_6202, %int8_6203, %int128_6204 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %4997 = torch.aten.view %4986, %4996 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %4997, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6205 = torch.constant.int 32 - %4998 = torch.aten.mul.int %437, %int32_6205 : !torch.int, !torch.int -> !torch.int - %int2_6206 = torch.constant.int 2 - %int32_6207 = torch.constant.int 32 - %int8_6208 = torch.constant.int 8 - %int128_6209 = torch.constant.int 128 - %4999 = torch.prim.ListConstruct %4998, %int2_6206, %int32_6207, %int8_6208, %int128_6209 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5000 = torch.aten.view %4997, %4999 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %5000, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_6210 = torch.constant.int 0 - %5001 = torch.aten.index_select %5000, %int0_6210, %4995 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %5001, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_6211 = torch.constant.int 4 - %int2_6212 = torch.constant.int 2 - %int32_6213 = torch.constant.int 32 - %int8_6214 = torch.constant.int 8 - %int128_6215 = 
torch.constant.int 128 - %5002 = torch.prim.ListConstruct %int4_6211, %358, %int2_6212, %int32_6213, %int8_6214, %int128_6215 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5003 = torch.aten.view %5001, %5002 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5003, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_6216 = torch.constant.int 0 - %int0_6217 = torch.constant.int 0 - %int9223372036854775807_6218 = torch.constant.int 9223372036854775807 + %int1_6193 = torch.constant.int 1 + %int4096_6194 = torch.constant.int 4096 + %5826 = torch.prim.ListConstruct %int4_6192, %int1_6193, %int4096_6194 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5827 = torch.aten.view %5825, %5826 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_6195 = torch.constant.int 1 + %5828 = torch.aten.add.Tensor %5649, %5827, %int1_6195 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_6196 = torch.constant.int 6 + %5829 = torch.prims.convert_element_type %5828, %int6_6196 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_6197 = torch.constant.int 2 + %5830 = torch.aten.pow.Tensor_Scalar %5829, %int2_6197 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_6198 = torch.constant.int -1 + %5831 = torch.prim.ListConstruct %int-1_6198 : (!torch.int) -> !torch.list + %true_6199 = torch.constant.bool true + %none_6200 = torch.constant.none + %5832 = torch.aten.mean.dim %5830, %5831, %true_6199, %none_6200 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_6201 = torch.constant.float 9.9999997473787516E-6 + %int1_6202 = torch.constant.int 1 + %5833 = torch.aten.add.Scalar %5832, %float9.999990e-06_6201, %int1_6202 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %5834 = torch.aten.rsqrt %5833 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %5835 = torch.aten.mul.Tensor %5829, %5834 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_6203 = torch.constant.int 5 + %5836 = torch.prims.convert_element_type %5835, %int5_6203 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %5837 = torch.aten.mul.Tensor %348, %5836 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_6204 = torch.constant.int 5 + %5838 = torch.prims.convert_element_type %5837, %int5_6204 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_6205 = torch.constant.int -2 + %int-1_6206 = torch.constant.int -1 + %5839 = torch.aten.transpose.int %349, %int-2_6205, %int-1_6206 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6207 = torch.constant.int 5 + %5840 = torch.prims.convert_element_type %5839, %int5_6207 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_6208 = torch.constant.int 4 + %int4096_6209 = torch.constant.int 4096 + %5841 = torch.prim.ListConstruct %int4_6208, %int4096_6209 : (!torch.int, !torch.int) -> !torch.list + %5842 = torch.aten.view %5838, %5841 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> 
!torch.vtensor<[4,4096],f16> + %5843 = torch.aten.mm %5842, %5840 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_6210 = torch.constant.int 4 + %int1_6211 = torch.constant.int 1 + %int14336_6212 = torch.constant.int 14336 + %5844 = torch.prim.ListConstruct %int4_6210, %int1_6211, %int14336_6212 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5845 = torch.aten.view %5843, %5844 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %5846 = torch.aten.silu %5845 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_6213 = torch.constant.int -2 + %int-1_6214 = torch.constant.int -1 + %5847 = torch.aten.transpose.int %350, %int-2_6213, %int-1_6214 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6215 = torch.constant.int 5 + %5848 = torch.prims.convert_element_type %5847, %int5_6215 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_6216 = torch.constant.int 4 + %int4096_6217 = torch.constant.int 4096 + %5849 = torch.prim.ListConstruct %int4_6216, %int4096_6217 : (!torch.int, !torch.int) -> !torch.list + %5850 = torch.aten.view %5838, %5849 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5851 = torch.aten.mm %5850, %5848 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_6218 = torch.constant.int 4 %int1_6219 = torch.constant.int 1 - %5004 = torch.aten.slice.Tensor %5003, %int0_6216, %int0_6217, %int9223372036854775807_6218, %int1_6219 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5004, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_6220 = torch.constant.int 1 - %int0_6221 = torch.constant.int 0 - %int9223372036854775807_6222 = torch.constant.int 9223372036854775807 - %int1_6223 = torch.constant.int 1 - %5005 = torch.aten.slice.Tensor %5004, %int1_6220, %int0_6221, %int9223372036854775807_6222, %int1_6223 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5005, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_6224 = torch.constant.int 2 - %int0_6225 = torch.constant.int 0 - %5006 = torch.aten.select.int %5005, %int2_6224, %int0_6225 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5006, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_6226 = torch.constant.int 32 - %5007 = torch.aten.mul.int %358, %int32_6226 : !torch.int, !torch.int -> !torch.int - %int2_6227 = torch.constant.int 2 - %int0_6228 = torch.constant.int 0 + %int14336_6220 = torch.constant.int 14336 + %5852 = torch.prim.ListConstruct %int4_6218, %int1_6219, %int14336_6220 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5853 = torch.aten.view %5851, %5852 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %5854 = torch.aten.mul.Tensor %5846, %5853 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_6221 = torch.constant.int -2 + %int-1_6222 = torch.constant.int -1 + %5855 = torch.aten.transpose.int 
%351, %int-2_6221, %int-1_6222 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_6223 = torch.constant.int 5 + %5856 = torch.prims.convert_element_type %5855, %int5_6223 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_6224 = torch.constant.int 4 + %int14336_6225 = torch.constant.int 14336 + %5857 = torch.prim.ListConstruct %int4_6224, %int14336_6225 : (!torch.int, !torch.int) -> !torch.list + %5858 = torch.aten.view %5854, %5857 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %5859 = torch.aten.mm %5858, %5856 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_6226 = torch.constant.int 4 + %int1_6227 = torch.constant.int 1 + %int4096_6228 = torch.constant.int 4096 + %5860 = torch.prim.ListConstruct %int4_6226, %int1_6227, %int4096_6228 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5861 = torch.aten.view %5859, %5860 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_6229 = torch.constant.int 1 - %5008 = torch.aten.slice.Tensor %5006, %int2_6227, %int0_6228, %5007, %int1_6229 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5008, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_6230 = torch.constant.int 0 - %5009 = torch.aten.clone %5008, %int0_6230 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5009, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_6231 = torch.constant.int 1 - %5010 = torch.aten.size.int %5005, %int1_6231 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_6232 = torch.constant.int 32 - %5011 = torch.aten.mul.int %5010, %int32_6232 : !torch.int, !torch.int -> !torch.int - %int4_6233 = torch.constant.int 4 - %int8_6234 = torch.constant.int 8 - %int128_6235 = torch.constant.int 128 - %5012 = torch.prim.ListConstruct %int4_6233, %5011, %int8_6234, %int128_6235 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5013 = torch.aten._unsafe_view %5009, %5012 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5013, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_6236 = torch.constant.int 0 - %int0_6237 = torch.constant.int 0 - %int9223372036854775807_6238 = torch.constant.int 9223372036854775807 - %int1_6239 = torch.constant.int 1 - %5014 = torch.aten.slice.Tensor %5013, %int0_6236, %int0_6237, %int9223372036854775807_6238, %int1_6239 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5014, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_6240 = torch.constant.int 0 - %int0_6241 = torch.constant.int 0 - %int9223372036854775807_6242 = torch.constant.int 9223372036854775807 - %int1_6243 = torch.constant.int 1 - %5015 = torch.aten.slice.Tensor %5003, %int0_6240, %int0_6241, %int9223372036854775807_6242, %int1_6243 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5015, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 
128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_6244 = torch.constant.int 1 - %int0_6245 = torch.constant.int 0 - %int9223372036854775807_6246 = torch.constant.int 9223372036854775807 - %int1_6247 = torch.constant.int 1 - %5016 = torch.aten.slice.Tensor %5015, %int1_6244, %int0_6245, %int9223372036854775807_6246, %int1_6247 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5016, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_6248 = torch.constant.int 2 - %int1_6249 = torch.constant.int 1 - %5017 = torch.aten.select.int %5016, %int2_6248, %int1_6249 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5017, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_6250 = torch.constant.int 2 - %int0_6251 = torch.constant.int 0 - %int1_6252 = torch.constant.int 1 - %5018 = torch.aten.slice.Tensor %5017, %int2_6250, %int0_6251, %5007, %int1_6252 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5018, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_6253 = torch.constant.int 0 - %5019 = torch.aten.clone %5018, %int0_6253 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5019, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_6254 = torch.constant.int 1 - %5020 = torch.aten.size.int %5016, %int1_6254 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_6255 = torch.constant.int 32 - %5021 = torch.aten.mul.int %5020, %int32_6255 : !torch.int, !torch.int -> !torch.int - %int4_6256 = torch.constant.int 4 - %int8_6257 = torch.constant.int 8 - %int128_6258 = torch.constant.int 128 - %5022 = torch.prim.ListConstruct %int4_6256, %5021, %int8_6257, %int128_6258 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5023 = torch.aten._unsafe_view %5019, %5022 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5023, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_6259 = torch.constant.int 0 - %int0_6260 = torch.constant.int 0 - %int9223372036854775807_6261 = torch.constant.int 9223372036854775807 - %int1_6262 = torch.constant.int 1 - %5024 = torch.aten.slice.Tensor %5023, %int0_6259, %int0_6260, %int9223372036854775807_6261, %int1_6262 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5024, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_6263 = torch.constant.int -2 - %5025 = torch.aten.unsqueeze %5014, %int-2_6263 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5025, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %5862 = torch.aten.add.Tensor %5828, %5861, %int1_6229 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_6230 = torch.constant.int 6 + %5863 = torch.prims.convert_element_type %5862, %int6_6230 : 
!torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_6231 = torch.constant.int 2 + %5864 = torch.aten.pow.Tensor_Scalar %5863, %int2_6231 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_6232 = torch.constant.int -1 + %5865 = torch.prim.ListConstruct %int-1_6232 : (!torch.int) -> !torch.list + %true_6233 = torch.constant.bool true + %none_6234 = torch.constant.none + %5866 = torch.aten.mean.dim %5864, %5865, %true_6233, %none_6234 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_6235 = torch.constant.float 9.9999997473787516E-6 + %int1_6236 = torch.constant.int 1 + %5867 = torch.aten.add.Scalar %5866, %float9.999990e-06_6235, %int1_6236 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %5868 = torch.aten.rsqrt %5867 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %5869 = torch.aten.mul.Tensor %5863, %5868 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_6237 = torch.constant.int 5 + %5870 = torch.prims.convert_element_type %5869, %int5_6237 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %5871 = torch.aten.mul.Tensor %352, %5870 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_6238 = torch.constant.int 5 + %5872 = torch.prims.convert_element_type %5871, %int5_6238 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_6239 = torch.constant.int -2 + %int-1_6240 = torch.constant.int -1 + %5873 = torch.aten.transpose.int %353, %int-2_6239, %int-1_6240 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6241 = torch.constant.int 5 + %5874 = torch.prims.convert_element_type %5873, %int5_6241 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_6242 = torch.constant.int 4 + %int4096_6243 = torch.constant.int 4096 + %5875 = torch.prim.ListConstruct %int4_6242, %int4096_6243 : (!torch.int, !torch.int) -> !torch.list + %5876 = torch.aten.view %5872, %5875 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5877 = torch.aten.mm %5876, %5874 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_6244 = torch.constant.int 4 + %int1_6245 = torch.constant.int 1 + %int4096_6246 = torch.constant.int 4096 + %5878 = torch.prim.ListConstruct %int4_6244, %int1_6245, %int4096_6246 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5879 = torch.aten.view %5877, %5878 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_6247 = torch.constant.int -2 + %int-1_6248 = torch.constant.int -1 + %5880 = torch.aten.transpose.int %354, %int-2_6247, %int-1_6248 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6249 = torch.constant.int 5 + %5881 = torch.prims.convert_element_type %5880, %int5_6249 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_6250 = torch.constant.int 4 + %int4096_6251 = torch.constant.int 4096 + %5882 = torch.prim.ListConstruct %int4_6250, %int4096_6251 : (!torch.int, !torch.int) -> !torch.list + %5883 = torch.aten.view %5872, %5882 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5884 = torch.aten.mm %5883, %5881 
: !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_6252 = torch.constant.int 4 + %int1_6253 = torch.constant.int 1 + %int1024_6254 = torch.constant.int 1024 + %5885 = torch.prim.ListConstruct %int4_6252, %int1_6253, %int1024_6254 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5886 = torch.aten.view %5884, %5885 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_6255 = torch.constant.int -2 + %int-1_6256 = torch.constant.int -1 + %5887 = torch.aten.transpose.int %355, %int-2_6255, %int-1_6256 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6257 = torch.constant.int 5 + %5888 = torch.prims.convert_element_type %5887, %int5_6257 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_6258 = torch.constant.int 4 + %int4096_6259 = torch.constant.int 4096 + %5889 = torch.prim.ListConstruct %int4_6258, %int4096_6259 : (!torch.int, !torch.int) -> !torch.list + %5890 = torch.aten.view %5872, %5889 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %5891 = torch.aten.mm %5890, %5888 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_6260 = torch.constant.int 4 + %int1_6261 = torch.constant.int 1 + %int1024_6262 = torch.constant.int 1024 + %5892 = torch.prim.ListConstruct %int4_6260, %int1_6261, %int1024_6262 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5893 = torch.aten.view %5891, %5892 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_6263 = torch.constant.int 4 %int1_6264 = torch.constant.int 1 - %5026 = torch.aten.size.int %5013, %int1_6264 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_6265 = torch.constant.int 4 - %int8_6266 = torch.constant.int 8 + %int32_6265 = torch.constant.int 32 + %int128_6266 = torch.constant.int 128 + %5894 = torch.prim.ListConstruct %int4_6263, %int1_6264, %int32_6265, %int128_6266 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5895 = torch.aten.view %5879, %5894 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> %int4_6267 = torch.constant.int 4 - %int128_6268 = torch.constant.int 128 - %5027 = torch.prim.ListConstruct %int4_6265, %5026, %int8_6266, %int4_6267, %int128_6268 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_6269 = torch.constant.bool false - %5028 = torch.aten.expand %5025, %5027, %false_6269 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5028, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_6270 = torch.constant.int 0 - %5029 = torch.aten.clone %5028, %int0_6270 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5029, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int1_6268 = torch.constant.int 1 + %int8_6269 = torch.constant.int 8 + %int128_6270 = torch.constant.int 128 + %5896 = torch.prim.ListConstruct %int4_6267, %int1_6268, %int8_6269, %int128_6270 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5897 = torch.aten.view %5886, %5896 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> %int4_6271 = torch.constant.int 4 - %int32_6272 = 
torch.constant.int 32 - %int128_6273 = torch.constant.int 128 - %5030 = torch.prim.ListConstruct %int4_6271, %5026, %int32_6272, %int128_6273 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5031 = torch.aten._unsafe_view %5029, %5030 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5031, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_6274 = torch.constant.int -2 - %5032 = torch.aten.unsqueeze %5024, %int-2_6274 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5032, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int1_6272 = torch.constant.int 1 + %int8_6273 = torch.constant.int 8 + %int128_6274 = torch.constant.int 128 + %5898 = torch.prim.ListConstruct %int4_6271, %int1_6272, %int8_6273, %int128_6274 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5899 = torch.aten.view %5893, %5898 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> %int1_6275 = torch.constant.int 1 - %5033 = torch.aten.size.int %5023, %int1_6275 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_6276 = torch.constant.int 4 - %int8_6277 = torch.constant.int 8 - %int4_6278 = torch.constant.int 4 - %int128_6279 = torch.constant.int 128 - %5034 = torch.prim.ListConstruct %int4_6276, %5033, %int8_6277, %int4_6278, %int128_6279 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_6280 = torch.constant.bool false - %5035 = torch.aten.expand %5032, %5034, %false_6280 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5035, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_6281 = torch.constant.int 0 - %5036 = torch.aten.clone %5035, %int0_6281 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5036, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_6282 = torch.constant.int 4 - %int32_6283 = torch.constant.int 32 - %int128_6284 = torch.constant.int 128 - %5037 = torch.prim.ListConstruct %int4_6282, %5033, %int32_6283, %int128_6284 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5038 = torch.aten._unsafe_view %5036, %5037 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5038, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_6285 = torch.constant.int 1 - %int2_6286 = torch.constant.int 2 - %5039 = torch.aten.transpose.int %4919, %int1_6285, %int2_6286 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int2_6276 = torch.constant.int 2 + %5900 = torch.aten.transpose.int %5895, %int1_6275, %int2_6276 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %5901 = torch.aten.mul.Tensor %5900, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_6277 = torch.constant.int 3 + %int0_6278 = torch.constant.int 0 + %int64_6279 = torch.constant.int 64 + %int1_6280 = torch.constant.int 1 + %5902 = torch.aten.slice.Tensor %5900, %int3_6277, %int0_6278, %int64_6279, %int1_6280 : 
!torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_6281 = torch.constant.int 3 + %int64_6282 = torch.constant.int 64 + %int9223372036854775807_6283 = torch.constant.int 9223372036854775807 + %int1_6284 = torch.constant.int 1 + %5903 = torch.aten.slice.Tensor %5900, %int3_6281, %int64_6282, %int9223372036854775807_6283, %int1_6284 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %5904 = torch.aten.neg %5903 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %5905 = torch.prim.ListConstruct %5904, %5902 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list<vtensor> + %int-1_6285 = torch.constant.int -1 + %5906 = torch.aten.cat %5905, %int-1_6285 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %5907 = torch.aten.mul.Tensor %5906, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_6286 = torch.constant.int 1 + %5908 = torch.aten.add.Tensor %5901, %5907, %int1_6286 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_6287 = torch.constant.int 1 %int2_6288 = torch.constant.int 2 - %5040 = torch.aten.transpose.int %5031, %int1_6287, %int2_6288 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5040, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %5909 = torch.aten.transpose.int %5908, %int1_6287, %int2_6288 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> %int1_6289 = torch.constant.int 1 %int2_6290 = torch.constant.int 2 - %5041 = torch.aten.transpose.int %5038, %int1_6289, %int2_6290 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5041, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_6291 = torch.constant.float 0.000000e+00 - %false_6292 = torch.constant.bool false - %none_6293 = torch.constant.none - %5042:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5039, %5040, %5041, %float0.000000e00_6291, %false_6292, %368, %none_6293) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %5910 = torch.aten.transpose.int %5897, %int1_6289, %int2_6290 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %5911 = torch.aten.mul.Tensor %5910, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_6291 = torch.constant.int 3 + %int0_6292 = torch.constant.int 0 + %int64_6293 = torch.constant.int 64 %int1_6294 = torch.constant.int 1 - %int2_6295 = torch.constant.int 2 - %5043 = torch.aten.transpose.int %5042#0, %int1_6294, %int2_6295 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_6296 = torch.constant.int 4 - %int1_6297 = torch.constant.int 1 - %int4096_6298 = torch.constant.int 4096 - %5044 = torch.prim.ListConstruct %int4_6296, %int1_6297, %int4096_6298 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> - %5045
= torch.aten.view %5043, %5044 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16> - %int-2_6299 = torch.constant.int -2 - %int-1_6300 = torch.constant.int -1 - %5046 = torch.aten.transpose.int %249, %int-2_6299, %int-1_6300 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6301 = torch.constant.int 4 - %int4096_6302 = torch.constant.int 4096 - %5047 = torch.prim.ListConstruct %int4_6301, %int4096_6302 : (!torch.int, !torch.int) -> !torch.list<int> - %5048 = torch.aten.view %5045, %5047 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16> - %5049 = torch.aten.mm %5048, %5046 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_6303 = torch.constant.int 4 + %5912 = torch.aten.slice.Tensor %5910, %int3_6291, %int0_6292, %int64_6293, %int1_6294 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_6295 = torch.constant.int 3 + %int64_6296 = torch.constant.int 64 + %int9223372036854775807_6297 = torch.constant.int 9223372036854775807 + %int1_6298 = torch.constant.int 1 + %5913 = torch.aten.slice.Tensor %5910, %int3_6295, %int64_6296, %int9223372036854775807_6297, %int1_6298 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %5914 = torch.aten.neg %5913 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %5915 = torch.prim.ListConstruct %5914, %5912 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list<vtensor> + %int-1_6299 = torch.constant.int -1 + %5916 = torch.aten.cat %5915, %int-1_6299 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %5917 = torch.aten.mul.Tensor %5916, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_6300 = torch.constant.int 1 + %5918 = torch.aten.add.Tensor %5911, %5917, %int1_6300 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_6301 = torch.constant.int 1 + %int2_6302 = torch.constant.int 2 + %5919 = torch.aten.transpose.int %5918, %int1_6301, %int2_6302 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_6303 = torch.constant.int 32 + %5920 = torch.aten.floor_divide.Scalar %arg2, %int32_6303 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> %int1_6304 = torch.constant.int 1 - %int4096_6305 = torch.constant.int 4096 - %5050 = torch.prim.ListConstruct %int4_6303, %int1_6304, %int4096_6305 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> - %5051 = torch.aten.view %5049, %5050 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16> - %int1_6306 = torch.constant.int 1 - %5052 = torch.aten.add.Tensor %4879, %5051, %int1_6306 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_6307 = torch.constant.int 6 - %5053 = torch.prims.convert_element_type %5052, %int6_6307 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_6308 = torch.constant.int 2 - %5054 = torch.aten.pow.Tensor_Scalar %5053, %int2_6308 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_6309 = torch.constant.int -1 - %5055 = torch.prim.ListConstruct %int-1_6309 : (!torch.int) -> !torch.list<int> - %true_6310 =
torch.constant.bool true - %none_6311 = torch.constant.none - %5056 = torch.aten.mean.dim %5054, %5055, %true_6310, %none_6311 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_6312 = torch.constant.float 9.9999997473787516E-6 + %5921 = torch.aten.unsqueeze %5920, %int1_6304 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_6305 = torch.constant.int 1 + %false_6306 = torch.constant.bool false + %5922 = torch.aten.gather %arg3, %int1_6305, %5921, %false_6306 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_6307 = torch.constant.int 4 + %int1_6308 = torch.constant.int 1 + %int1_6309 = torch.constant.int 1 + %5923 = torch.prim.ListConstruct %int4_6307, %int1_6308, %int1_6309 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5924 = torch.aten.view %5922, %5923 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_6310 = torch.constant.int 32 + %5925 = torch.aten.remainder.Scalar %arg2, %int32_6310 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_6311 = torch.constant.int 4 + %int1_6312 = torch.constant.int 1 %int1_6313 = torch.constant.int 1 - %5057 = torch.aten.add.Scalar %5056, %float9.999990e-06_6312, %int1_6313 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %5058 = torch.aten.rsqrt %5057 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %5059 = torch.aten.mul.Tensor %5053, %5058 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_6314 = torch.constant.int 5 - %5060 = torch.prims.convert_element_type %5059, %int5_6314 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %5061 = torch.aten.mul.Tensor %250, %5060 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_6315 = torch.constant.int 5 - %5062 = torch.prims.convert_element_type %5061, %int5_6315 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_6316 = torch.constant.int -2 - %int-1_6317 = torch.constant.int -1 - %5063 = torch.aten.transpose.int %251, %int-2_6316, %int-1_6317 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_6318 = torch.constant.int 4 - %int4096_6319 = torch.constant.int 4096 - %5064 = torch.prim.ListConstruct %int4_6318, %int4096_6319 : (!torch.int, !torch.int) -> !torch.list - %5065 = torch.aten.view %5062, %5064 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5066 = torch.aten.mm %5065, %5063 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_6320 = torch.constant.int 4 - %int1_6321 = torch.constant.int 1 - %int14336_6322 = torch.constant.int 14336 - %5067 = torch.prim.ListConstruct %int4_6320, %int1_6321, %int14336_6322 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5068 = torch.aten.view %5066, %5067 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %5069 = torch.aten.silu %5068 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_6323 = torch.constant.int -2 - %int-1_6324 = torch.constant.int -1 - %5070 = torch.aten.transpose.int %252, %int-2_6323, %int-1_6324 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - 
%int4_6325 = torch.constant.int 4 - %int4096_6326 = torch.constant.int 4096 - %5071 = torch.prim.ListConstruct %int4_6325, %int4096_6326 : (!torch.int, !torch.int) -> !torch.list - %5072 = torch.aten.view %5062, %5071 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5073 = torch.aten.mm %5072, %5070 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_6327 = torch.constant.int 4 - %int1_6328 = torch.constant.int 1 - %int14336_6329 = torch.constant.int 14336 - %5074 = torch.prim.ListConstruct %int4_6327, %int1_6328, %int14336_6329 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5075 = torch.aten.view %5073, %5074 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %5076 = torch.aten.mul.Tensor %5069, %5075 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_6330 = torch.constant.int -2 - %int-1_6331 = torch.constant.int -1 - %5077 = torch.aten.transpose.int %253, %int-2_6330, %int-1_6331 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_6332 = torch.constant.int 4 - %int14336_6333 = torch.constant.int 14336 - %5078 = torch.prim.ListConstruct %int4_6332, %int14336_6333 : (!torch.int, !torch.int) -> !torch.list - %5079 = torch.aten.view %5076, %5078 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %5080 = torch.aten.mm %5079, %5077 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_6334 = torch.constant.int 4 - %int1_6335 = torch.constant.int 1 - %int4096_6336 = torch.constant.int 4096 - %5081 = torch.prim.ListConstruct %int4_6334, %int1_6335, %int4096_6336 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5082 = torch.aten.view %5080, %5081 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_6337 = torch.constant.int 1 - %5083 = torch.aten.add.Tensor %5052, %5082, %int1_6337 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_6338 = torch.constant.int 6 - %5084 = torch.prims.convert_element_type %5083, %int6_6338 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_6339 = torch.constant.int 2 - %5085 = torch.aten.pow.Tensor_Scalar %5084, %int2_6339 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_6340 = torch.constant.int -1 - %5086 = torch.prim.ListConstruct %int-1_6340 : (!torch.int) -> !torch.list - %true_6341 = torch.constant.bool true - %none_6342 = torch.constant.none - %5087 = torch.aten.mean.dim %5085, %5086, %true_6341, %none_6342 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_6343 = torch.constant.float 9.9999997473787516E-6 - %int1_6344 = torch.constant.int 1 - %5088 = torch.aten.add.Scalar %5087, %float9.999990e-06_6343, %int1_6344 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %5089 = torch.aten.rsqrt %5088 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %5090 = torch.aten.mul.Tensor %5084, %5089 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_6345 = torch.constant.int 5 - %5091 = torch.prims.convert_element_type %5090, %int5_6345 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> 
!torch.vtensor<[4,1,4096],f16> - %5092 = torch.aten.mul.Tensor %254, %5091 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_6346 = torch.constant.int 5 - %5093 = torch.prims.convert_element_type %5092, %int5_6346 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_6347 = torch.constant.int -2 - %int-1_6348 = torch.constant.int -1 - %5094 = torch.aten.transpose.int %255, %int-2_6347, %int-1_6348 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6349 = torch.constant.int 4 - %int4096_6350 = torch.constant.int 4096 - %5095 = torch.prim.ListConstruct %int4_6349, %int4096_6350 : (!torch.int, !torch.int) -> !torch.list - %5096 = torch.aten.view %5093, %5095 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5097 = torch.aten.mm %5096, %5094 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_6351 = torch.constant.int 4 - %int1_6352 = torch.constant.int 1 - %int4096_6353 = torch.constant.int 4096 - %5098 = torch.prim.ListConstruct %int4_6351, %int1_6352, %int4096_6353 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5099 = torch.aten.view %5097, %5098 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_6354 = torch.constant.int -2 - %int-1_6355 = torch.constant.int -1 - %5100 = torch.aten.transpose.int %256, %int-2_6354, %int-1_6355 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6356 = torch.constant.int 4 - %int4096_6357 = torch.constant.int 4096 - %5101 = torch.prim.ListConstruct %int4_6356, %int4096_6357 : (!torch.int, !torch.int) -> !torch.list - %5102 = torch.aten.view %5093, %5101 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5103 = torch.aten.mm %5102, %5100 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_6358 = torch.constant.int 4 - %int1_6359 = torch.constant.int 1 - %int1024_6360 = torch.constant.int 1024 - %5104 = torch.prim.ListConstruct %int4_6358, %int1_6359, %int1024_6360 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5105 = torch.aten.view %5103, %5104 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_6361 = torch.constant.int -2 - %int-1_6362 = torch.constant.int -1 - %5106 = torch.aten.transpose.int %257, %int-2_6361, %int-1_6362 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6363 = torch.constant.int 4 - %int4096_6364 = torch.constant.int 4096 - %5107 = torch.prim.ListConstruct %int4_6363, %int4096_6364 : (!torch.int, !torch.int) -> !torch.list - %5108 = torch.aten.view %5093, %5107 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5109 = torch.aten.mm %5108, %5106 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_6365 = torch.constant.int 4 + %5926 = torch.prim.ListConstruct %int4_6311, %int1_6312, %int1_6313 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5927 = torch.aten.view %5925, %5926 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_6314 = torch.constant.int 8 + %none_6315 = torch.constant.none + %none_6316 = torch.constant.none + %cpu_6317 = torch.constant.device "cpu" + %false_6318 = torch.constant.bool false + %5928 = torch.aten.arange 
%int8_6314, %none_6315, %none_6316, %cpu_6317, %false_6318 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_6319 = torch.constant.int 1 + %int1_6320 = torch.constant.int 1 + %int8_6321 = torch.constant.int 8 + %5929 = torch.prim.ListConstruct %int1_6319, %int1_6320, %int8_6321 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5930 = torch.aten.view %5928, %5929 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_6322 = torch.constant.none + %5931 = torch.aten.clone %356, %none_6322 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5932 = torch.aten.detach %5931 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5933 = torch.aten.detach %5932 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5934 = torch.aten.detach %5933 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_6323 = torch.constant.int 1 + %int1_6324 = torch.constant.int 1 + %int1_6325 = torch.constant.int 1 + %5935 = torch.prim.ListConstruct %int1_6323, %int1_6324, %int1_6325 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5936 = torch.aten.view %5934, %5935 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_6326 = torch.constant.int 32 + %5937 = torch.aten.mul.Scalar %5924, %int32_6326 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int25 = torch.constant.int 25 + %int1_6327 = torch.constant.int 1 + %5938 = torch.aten.add.Scalar %5937, %int25, %int1_6327 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_6328 = torch.constant.int 2 + %5939 = torch.aten.mul.Scalar %5938, %int2_6328 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_6329 = torch.constant.int 1 + %5940 = torch.aten.add.Tensor %5939, %5936, %int1_6329 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_6330 = torch.constant.int 8 + %5941 = torch.aten.mul.Scalar %5940, %int8_6330 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_6331 = torch.constant.int 1 + %5942 = torch.aten.add.Tensor %5941, %5930, %int1_6331 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_6332 = torch.constant.int 32 + %5943 = torch.aten.mul.Scalar %5942, %int32_6332 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_6333 = torch.constant.int 1 + %5944 = torch.aten.add.Tensor %5943, %5927, %int1_6333 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_6334 = torch.constant.int 5 + %5945 = torch.prims.convert_element_type %5919, %int5_6334 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_6335 = torch.constant.int 32 + %int2_6336 = torch.constant.int 2 + %int8_6337 = torch.constant.int 8 + %int32_6338 = torch.constant.int 32 + %int128_6339 = torch.constant.int 128 + %5946 = torch.prim.ListConstruct %456, %int32_6335, %int2_6336, %int8_6337, %int32_6338, %int128_6339 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5947 = torch.aten.view %5767, %5946 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5947, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_6340 = torch.constant.int 
128 + %5948 = torch.prim.ListConstruct %596, %int128_6340 : (!torch.int, !torch.int) -> !torch.list + %5949 = torch.aten.view %5947, %5948 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5949, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %5950 = torch.prim.ListConstruct %5944 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_6341 = torch.constant.bool false + %5951 = torch.aten.index_put %5949, %5950, %5945, %false_6341 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5951, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_6342 = torch.constant.int 32 + %int2_6343 = torch.constant.int 2 + %int8_6344 = torch.constant.int 8 + %int32_6345 = torch.constant.int 32 + %int128_6346 = torch.constant.int 128 + %5952 = torch.prim.ListConstruct %456, %int32_6342, %int2_6343, %int8_6344, %int32_6345, %int128_6346 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5953 = torch.aten.view %5951, %5952 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5953, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_6347 = torch.constant.int 2097152 + %5954 = torch.prim.ListConstruct %456, %int2097152_6347 : (!torch.int, !torch.int) -> !torch.list + %5955 = torch.aten.view %5953, %5954 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5955, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_6348 = torch.constant.int 32 + %int2_6349 = torch.constant.int 2 + %int8_6350 = torch.constant.int 8 + %int32_6351 = torch.constant.int 32 + %int128_6352 = torch.constant.int 128 + %5956 = torch.prim.ListConstruct %456, %int32_6348, %int2_6349, %int8_6350, %int32_6351, %int128_6352 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5957 = torch.aten.view %5955, %5956 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5957, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_6353 = torch.constant.int 128 + %5958 = torch.prim.ListConstruct %596, %int128_6353 : (!torch.int, !torch.int) -> !torch.list + %5959 = torch.aten.view %5957, %5958 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5959, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_6354 = torch.constant.none + %5960 = torch.aten.clone %357, %none_6354 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5961 = torch.aten.detach %5960 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5962 = torch.aten.detach %5961 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5963 = torch.aten.detach %5962 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_6355 = torch.constant.int 1 + %int1_6356 = torch.constant.int 1 + %int1_6357 = torch.constant.int 1 + %5964 = torch.prim.ListConstruct %int1_6355, %int1_6356, %int1_6357 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5965 = torch.aten.view %5963, %5964 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + 
%int32_6358 = torch.constant.int 32 + %5966 = torch.aten.mul.Scalar %5924, %int32_6358 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int25_6359 = torch.constant.int 25 + %int1_6360 = torch.constant.int 1 + %5967 = torch.aten.add.Scalar %5966, %int25_6359, %int1_6360 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_6361 = torch.constant.int 2 + %5968 = torch.aten.mul.Scalar %5967, %int2_6361 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_6362 = torch.constant.int 1 + %5969 = torch.aten.add.Tensor %5968, %5965, %int1_6362 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_6363 = torch.constant.int 8 + %5970 = torch.aten.mul.Scalar %5969, %int8_6363 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_6364 = torch.constant.int 1 + %5971 = torch.aten.add.Tensor %5970, %5930, %int1_6364 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_6365 = torch.constant.int 32 + %5972 = torch.aten.mul.Scalar %5971, %int32_6365 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_6366 = torch.constant.int 1 - %int1024_6367 = torch.constant.int 1024 - %5110 = torch.prim.ListConstruct %int4_6365, %int1_6366, %int1024_6367 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5111 = torch.aten.view %5109, %5110 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_6368 = torch.constant.int 4 - %int1_6369 = torch.constant.int 1 - %int32_6370 = torch.constant.int 32 - %int128_6371 = torch.constant.int 128 - %5112 = torch.prim.ListConstruct %int4_6368, %int1_6369, %int32_6370, %int128_6371 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5113 = torch.aten.view %5099, %5112 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_6372 = torch.constant.int 4 - %int1_6373 = torch.constant.int 1 - %int8_6374 = torch.constant.int 8 - %int128_6375 = torch.constant.int 128 - %5114 = torch.prim.ListConstruct %int4_6372, %int1_6373, %int8_6374, %int128_6375 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5115 = torch.aten.view %5105, %5114 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_6376 = torch.constant.int 4 - %int1_6377 = torch.constant.int 1 - %int8_6378 = torch.constant.int 8 - %int128_6379 = torch.constant.int 128 - %5116 = torch.prim.ListConstruct %int4_6376, %int1_6377, %int8_6378, %int128_6379 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5117 = torch.aten.view %5111, %5116 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_6380 = torch.constant.int 6 - %5118 = torch.prims.convert_element_type %5113, %int6_6380 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %5119 = torch_c.to_builtin_tensor %5118 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %5120 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %5121 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%5119, %5120) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %5122 = torch_c.from_builtin_tensor %5121 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_6381 = torch.constant.int 5 - %5123 = 
torch.prims.convert_element_type %5122, %int5_6381 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_6382 = torch.constant.int 6 - %5124 = torch.prims.convert_element_type %5115, %int6_6382 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %5125 = torch_c.to_builtin_tensor %5124 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %5126 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %5127 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%5125, %5126) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %5128 = torch_c.from_builtin_tensor %5127 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_6383 = torch.constant.int 5 - %5129 = torch.prims.convert_element_type %5128, %int5_6383 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_6384 = torch.constant.int 32 - %5130 = torch.aten.floor_divide.Scalar %arg2, %int32_6384 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_6385 = torch.constant.int 1 - %5131 = torch.aten.unsqueeze %5130, %int1_6385 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6386 = torch.constant.int 1 - %false_6387 = torch.constant.bool false - %5132 = torch.aten.gather %arg3, %int1_6386, %5131, %false_6387 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_6388 = torch.constant.int 32 - %5133 = torch.aten.remainder.Scalar %arg2, %int32_6388 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_6389 = torch.constant.int 1 - %5134 = torch.aten.unsqueeze %5133, %int1_6389 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_6390 = torch.constant.none - %5135 = torch.aten.clone %258, %none_6390 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_6391 = torch.constant.int 0 - %5136 = torch.aten.unsqueeze %5135, %int0_6391 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_6392 = torch.constant.int 4 - %int1_6393 = torch.constant.int 1 - %5137 = torch.prim.ListConstruct %int4_6392, %int1_6393 : (!torch.int, !torch.int) -> !torch.list - %int1_6394 = torch.constant.int 1 - %int1_6395 = torch.constant.int 1 - %5138 = torch.prim.ListConstruct %int1_6394, %int1_6395 : (!torch.int, !torch.int) -> !torch.list + %5973 = torch.aten.add.Tensor %5972, %5927, %int1_6366 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_6367 = torch.constant.int 5 + %5974 = torch.prims.convert_element_type %5899, %int5_6367 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %5975 = torch.prim.ListConstruct %5973 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_6368 = torch.constant.bool false + %5976 = torch.aten.index_put %5959, %5975, %5974, %false_6368 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %5976, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_6369 = torch.constant.int 32 + %int2_6370 = torch.constant.int 2 + %int8_6371 = torch.constant.int 8 + %int32_6372 = torch.constant.int 32 + %int128_6373 = torch.constant.int 128 + %5977 = torch.prim.ListConstruct %456, %int32_6369, %int2_6370, %int8_6371, %int32_6372, %int128_6373 : (!torch.int, 
!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5978 = torch.aten.view %5976, %5977 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5978, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_6374 = torch.constant.int 2097152 + %5979 = torch.prim.ListConstruct %456, %int2097152_6374 : (!torch.int, !torch.int) -> !torch.list + %5980 = torch.aten.view %5978, %5979 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %5980, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_6375 = torch.constant.none + %5981 = torch.aten.clone %358, %none_6375 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5982 = torch.aten.detach %5981 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5983 = torch.aten.detach %5982 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5984 = torch.aten.detach %5983 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_6376 = torch.constant.none + %5985 = torch.aten.clone %359, %none_6376 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5986 = torch.aten.detach %5985 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5987 = torch.aten.detach %5986 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5988 = torch.aten.detach %5987 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_6377 = torch.constant.none + %5989 = torch.aten.clone %360, %none_6377 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %5990 = torch.aten.detach %5989 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5991 = torch.aten.detach %5990 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %5992 = torch.aten.detach %5991 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_6378 = torch.constant.int 32 + %int2_6379 = torch.constant.int 2 + %int8_6380 = torch.constant.int 8 + %int32_6381 = torch.constant.int 32 + %int128_6382 = torch.constant.int 128 + %5993 = torch.prim.ListConstruct %456, %int32_6378, %int2_6379, %int8_6380, %int32_6381, %int128_6382 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5994 = torch.aten.view %5980, %5993 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %5994, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %5995 = torch_c.to_builtin_tensor %5994 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %5996 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_6383 = tensor.cast %5996 : tensor<4x?xi64> to tensor + %5997 = torch_c.to_builtin_tensor %5984 : !torch.vtensor<[],si64> -> tensor + %5998 = torch_c.to_builtin_tensor %5988 : !torch.vtensor<[],si64> -> tensor + %5999 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%5995, %cast_6383, %5997, %5998) : (tensor, tensor, tensor, tensor) -> tensor + %cast_6384 = tensor.cast %5999 : tensor to tensor<4x?x8x32x128xf16> + %6000 = torch_c.from_builtin_tensor %cast_6384 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %6000, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : 
!torch.vtensor<[4,?,8,32,128],f16> + %6001 = torch_c.to_builtin_tensor %5994 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %6002 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_6385 = tensor.cast %6002 : tensor<4x?xi64> to tensor + %6003 = torch_c.to_builtin_tensor %5984 : !torch.vtensor<[],si64> -> tensor + %6004 = torch_c.to_builtin_tensor %5992 : !torch.vtensor<[],si64> -> tensor + %6005 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%6001, %cast_6385, %6003, %6004) : (tensor, tensor, tensor, tensor) -> tensor + %cast_6386 = tensor.cast %6005 : tensor to tensor<4x?x8x32x128xf16> + %6006 = torch_c.from_builtin_tensor %cast_6386 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %6006, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_6387 = torch.constant.int 2 + %int3_6388 = torch.constant.int 3 + %6007 = torch.aten.transpose.int %6000, %int2_6387, %int3_6388 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6007, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_6389 = torch.constant.int 0 + %6008 = torch.aten.clone %6007, %int0_6389 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6008, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_6390 = torch.constant.int 4 + %int8_6391 = torch.constant.int 8 + %int128_6392 = torch.constant.int 128 + %6009 = torch.prim.ListConstruct %int4_6390, %457, %int8_6391, %int128_6392 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6010 = torch.aten._unsafe_view %6008, %6009 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6010, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_6393 = torch.constant.int 2 + %int3_6394 = torch.constant.int 3 + %6011 = torch.aten.transpose.int %6006, %int2_6393, %int3_6394 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6011, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_6395 = torch.constant.int 0 + %6012 = torch.aten.clone %6011, %int0_6395 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6012, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int4_6396 = torch.constant.int 4 - %int0_6397 = torch.constant.int 0 - %cpu_6398 = torch.constant.device "cpu" - %false_6399 = torch.constant.bool false - %5139 = torch.aten.empty_strided %5137, %5138, %int4_6396, %int0_6397, %cpu_6398, %false_6399 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int23 = torch.constant.int 23 - %5140 = torch.aten.fill.Scalar %5139, %int23 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int8_6397 = torch.constant.int 8 + %int128_6398 = torch.constant.int 128 + %6013 = torch.prim.ListConstruct %int4_6396, %457, %int8_6397, %int128_6398 : (!torch.int, !torch.int, 
!torch.int, !torch.int) -> !torch.list + %6014 = torch.aten._unsafe_view %6012, %6013 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6014, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_6399 = torch.constant.int -2 + %6015 = torch.aten.unsqueeze %6010, %int-2_6399 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %6015, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> %int4_6400 = torch.constant.int 4 - %int1_6401 = torch.constant.int 1 - %5141 = torch.prim.ListConstruct %int4_6400, %int1_6401 : (!torch.int, !torch.int) -> !torch.list - %5142 = torch.aten.repeat %5136, %5141 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_6402 = torch.constant.int 32 - %5143 = torch.aten.mul.Scalar %5132, %int32_6402 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6403 = torch.constant.int 1 - %5144 = torch.aten.add.Tensor %5143, %5140, %int1_6403 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_6404 = torch.constant.int 2 - %5145 = torch.aten.mul.Scalar %5144, %int2_6404 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6405 = torch.constant.int 1 - %5146 = torch.aten.add.Tensor %5145, %5142, %int1_6405 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_6406 = torch.constant.int 32 - %5147 = torch.aten.mul.Scalar %5146, %int32_6406 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6407 = torch.constant.int 1 - %5148 = torch.aten.add.Tensor %5147, %5134, %int1_6407 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_6408 = torch.constant.int 32 - %int2_6409 = torch.constant.int 2 - %int32_6410 = torch.constant.int 32 + %int8_6401 = torch.constant.int 8 + %int4_6402 = torch.constant.int 4 + %int128_6403 = torch.constant.int 128 + %6016 = torch.prim.ListConstruct %int4_6400, %457, %int8_6401, %int4_6402, %int128_6403 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_6404 = torch.constant.bool false + %6017 = torch.aten.expand %6015, %6016, %false_6404 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6017, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_6405 = torch.constant.int 0 + %6018 = torch.aten.clone %6017, %int0_6405 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6018, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_6406 = torch.constant.int 4 + %int32_6407 = torch.constant.int 32 + %int128_6408 = torch.constant.int 128 + %6019 = torch.prim.ListConstruct %int4_6406, %457, %int32_6407, %int128_6408 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6020 = torch.aten._unsafe_view %6018, %6019 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6020, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_6409 = torch.constant.int -2 + %6021 = torch.aten.unsqueeze %6014, %int-2_6409 : 
!torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %6021, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_6410 = torch.constant.int 4 %int8_6411 = torch.constant.int 8 - %int128_6412 = torch.constant.int 128 - %5149 = torch.prim.ListConstruct %437, %int32_6408, %int2_6409, %int32_6410, %int8_6411, %int128_6412 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5150 = torch.aten.view %4986, %5149 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5150, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6413 = torch.constant.int 32 - %5151 = torch.aten.mul.int %437, %int32_6413 : !torch.int, !torch.int -> !torch.int - %int2_6414 = torch.constant.int 2 - %5152 = torch.aten.mul.int %5151, %int2_6414 : !torch.int, !torch.int -> !torch.int - %int32_6415 = torch.constant.int 32 - %5153 = torch.aten.mul.int %5152, %int32_6415 : !torch.int, !torch.int -> !torch.int - %int8_6416 = torch.constant.int 8 - %int128_6417 = torch.constant.int 128 - %5154 = torch.prim.ListConstruct %5153, %int8_6416, %int128_6417 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5155 = torch.aten.view %5150, %5154 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5155, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %5156 = torch.prim.ListConstruct %5148 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_6418 = torch.constant.bool false - %5157 = torch.aten.index_put %5155, %5156, %5129, %false_6418 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5157, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_6419 = torch.constant.int 32 + %int4_6412 = torch.constant.int 4 + %int128_6413 = torch.constant.int 128 + %6022 = torch.prim.ListConstruct %int4_6410, %457, %int8_6411, %int4_6412, %int128_6413 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_6414 = torch.constant.bool false + %6023 = torch.aten.expand %6021, %6022, %false_6414 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6023, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_6415 = torch.constant.int 0 + %6024 = torch.aten.clone %6023, %int0_6415 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6024, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_6416 = torch.constant.int 4 + %int32_6417 = torch.constant.int 32 + %int128_6418 = torch.constant.int 128 + %6025 = torch.prim.ListConstruct %int4_6416, %457, %int32_6417, %int128_6418 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6026 = torch.aten._unsafe_view %6024, %6025 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6026, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_6419 = torch.constant.int 1 %int2_6420 = torch.constant.int 2 - %int32_6421 = torch.constant.int 32 - %int8_6422 = 
torch.constant.int 8 - %int128_6423 = torch.constant.int 128 - %5158 = torch.prim.ListConstruct %437, %int32_6419, %int2_6420, %int32_6421, %int8_6422, %int128_6423 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5159 = torch.aten.view %5157, %5158 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5159, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6424 = torch.constant.int 2097152 - %5160 = torch.prim.ListConstruct %437, %int2097152_6424 : (!torch.int, !torch.int) -> !torch.list - %5161 = torch.aten.view %5159, %5160 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5161, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_6425 = torch.constant.int 32 - %int2_6426 = torch.constant.int 2 - %int32_6427 = torch.constant.int 32 - %int8_6428 = torch.constant.int 8 - %int128_6429 = torch.constant.int 128 - %5162 = torch.prim.ListConstruct %437, %int32_6425, %int2_6426, %int32_6427, %int8_6428, %int128_6429 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5163 = torch.aten.view %5161, %5162 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5163, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_6430 = torch.constant.int 8 - %int128_6431 = torch.constant.int 128 - %5164 = torch.prim.ListConstruct %5153, %int8_6430, %int128_6431 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5165 = torch.aten.view %5163, %5164 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5165, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_6432 = torch.constant.int 32 - %5166 = torch.aten.floor_divide.Scalar %arg2, %int32_6432 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_6433 = torch.constant.int 1 - %5167 = torch.aten.unsqueeze %5166, %int1_6433 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6434 = torch.constant.int 1 - %false_6435 = torch.constant.bool false - %5168 = torch.aten.gather %arg3, %int1_6434, %5167, %false_6435 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_6436 = torch.constant.int 32 - %5169 = torch.aten.remainder.Scalar %arg2, %int32_6436 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_6437 = torch.constant.int 1 - %5170 = torch.aten.unsqueeze %5169, %int1_6437 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_6438 = torch.constant.none - %5171 = torch.aten.clone %259, %none_6438 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_6439 = torch.constant.int 0 - %5172 = torch.aten.unsqueeze %5171, %int0_6439 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_6440 = torch.constant.int 4 + %6027 = torch.aten.transpose.int %5909, %int1_6419, %int2_6420 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_6421 = torch.constant.int 1 + %int2_6422 = torch.constant.int 2 + %6028 = torch.aten.transpose.int %6020, %int1_6421, %int2_6422 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int 
-> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6028, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_6423 = torch.constant.int 1 + %int2_6424 = torch.constant.int 2 + %6029 = torch.aten.transpose.int %6026, %int1_6423, %int2_6424 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6029, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_6425 = torch.constant.float 0.000000e+00 + %false_6426 = torch.constant.bool false + %none_6427 = torch.constant.none + %6030:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6027, %6028, %6029, %float0.000000e00_6425, %false_6426, %470, %none_6427) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_6428 = torch.constant.int 1 + %int2_6429 = torch.constant.int 2 + %6031 = torch.aten.transpose.int %6030#0, %int1_6428, %int2_6429 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_6430 = torch.constant.int 4 + %int1_6431 = torch.constant.int 1 + %int4096_6432 = torch.constant.int 4096 + %6032 = torch.prim.ListConstruct %int4_6430, %int1_6431, %int4096_6432 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6033 = torch.aten.view %6031, %6032 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_6433 = torch.constant.int -2 + %int-1_6434 = torch.constant.int -1 + %6034 = torch.aten.transpose.int %361, %int-2_6433, %int-1_6434 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6435 = torch.constant.int 5 + %6035 = torch.prims.convert_element_type %6034, %int5_6435 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_6436 = torch.constant.int 4 + %int4096_6437 = torch.constant.int 4096 + %6036 = torch.prim.ListConstruct %int4_6436, %int4096_6437 : (!torch.int, !torch.int) -> !torch.list + %6037 = torch.aten.view %6033, %6036 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6038 = torch.aten.mm %6037, %6035 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_6438 = torch.constant.int 4 + %int1_6439 = torch.constant.int 1 + %int4096_6440 = torch.constant.int 4096 + %6039 = torch.prim.ListConstruct %int4_6438, %int1_6439, %int4096_6440 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6040 = torch.aten.view %6038, %6039 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_6441 = torch.constant.int 1 - %5173 = torch.prim.ListConstruct %int4_6440, %int1_6441 : (!torch.int, !torch.int) -> !torch.list - %int1_6442 = torch.constant.int 1 - %int1_6443 = torch.constant.int 1 - %5174 = torch.prim.ListConstruct %int1_6442, %int1_6443 : (!torch.int, !torch.int) -> !torch.list - %int4_6444 = torch.constant.int 4 - %int0_6445 = torch.constant.int 0 - %cpu_6446 = torch.constant.device "cpu" - %false_6447 = torch.constant.bool false - %5175 = torch.aten.empty_strided %5173, %5174, %int4_6444, %int0_6445, %cpu_6446, %false_6447 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int23_6448 = torch.constant.int 
23 - %5176 = torch.aten.fill.Scalar %5175, %int23_6448 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_6449 = torch.constant.int 4 - %int1_6450 = torch.constant.int 1 - %5177 = torch.prim.ListConstruct %int4_6449, %int1_6450 : (!torch.int, !torch.int) -> !torch.list - %5178 = torch.aten.repeat %5172, %5177 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_6451 = torch.constant.int 32 - %5179 = torch.aten.mul.Scalar %5168, %int32_6451 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6452 = torch.constant.int 1 - %5180 = torch.aten.add.Tensor %5179, %5176, %int1_6452 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_6453 = torch.constant.int 2 - %5181 = torch.aten.mul.Scalar %5180, %int2_6453 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6454 = torch.constant.int 1 - %5182 = torch.aten.add.Tensor %5181, %5178, %int1_6454 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_6455 = torch.constant.int 32 - %5183 = torch.aten.mul.Scalar %5182, %int32_6455 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6456 = torch.constant.int 1 - %5184 = torch.aten.add.Tensor %5183, %5170, %int1_6456 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %5185 = torch.prim.ListConstruct %5184 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_6457 = torch.constant.bool false - %5186 = torch.aten.index_put %5165, %5185, %5117, %false_6457 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5186, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_6458 = torch.constant.int 32 - %int2_6459 = torch.constant.int 2 - %int32_6460 = torch.constant.int 32 - %int8_6461 = torch.constant.int 8 - %int128_6462 = torch.constant.int 128 - %5187 = torch.prim.ListConstruct %437, %int32_6458, %int2_6459, %int32_6460, %int8_6461, %int128_6462 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5188 = torch.aten.view %5186, %5187 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5188, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6463 = torch.constant.int 2097152 - %5189 = torch.prim.ListConstruct %437, %int2097152_6463 : (!torch.int, !torch.int) -> !torch.list - %5190 = torch.aten.view %5188, %5189 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5190, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %6041 = torch.aten.add.Tensor %5862, %6040, %int1_6441 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_6442 = torch.constant.int 6 + %6042 = torch.prims.convert_element_type %6041, %int6_6442 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_6443 = torch.constant.int 2 + %6043 = torch.aten.pow.Tensor_Scalar %6042, %int2_6443 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_6444 = torch.constant.int -1 + %6044 = torch.prim.ListConstruct %int-1_6444 : (!torch.int) -> !torch.list + 
%true_6445 = torch.constant.bool true + %none_6446 = torch.constant.none + %6045 = torch.aten.mean.dim %6043, %6044, %true_6445, %none_6446 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_6447 = torch.constant.float 9.9999997473787516E-6 + %int1_6448 = torch.constant.int 1 + %6046 = torch.aten.add.Scalar %6045, %float9.999990e-06_6447, %int1_6448 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %6047 = torch.aten.rsqrt %6046 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %6048 = torch.aten.mul.Tensor %6042, %6047 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_6449 = torch.constant.int 5 + %6049 = torch.prims.convert_element_type %6048, %int5_6449 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %6050 = torch.aten.mul.Tensor %362, %6049 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_6450 = torch.constant.int 5 + %6051 = torch.prims.convert_element_type %6050, %int5_6450 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_6451 = torch.constant.int -2 + %int-1_6452 = torch.constant.int -1 + %6052 = torch.aten.transpose.int %363, %int-2_6451, %int-1_6452 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6453 = torch.constant.int 5 + %6053 = torch.prims.convert_element_type %6052, %int5_6453 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_6454 = torch.constant.int 4 + %int4096_6455 = torch.constant.int 4096 + %6054 = torch.prim.ListConstruct %int4_6454, %int4096_6455 : (!torch.int, !torch.int) -> !torch.list + %6055 = torch.aten.view %6051, %6054 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6056 = torch.aten.mm %6055, %6053 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_6456 = torch.constant.int 4 + %int1_6457 = torch.constant.int 1 + %int14336_6458 = torch.constant.int 14336 + %6057 = torch.prim.ListConstruct %int4_6456, %int1_6457, %int14336_6458 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6058 = torch.aten.view %6056, %6057 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %6059 = torch.aten.silu %6058 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_6459 = torch.constant.int -2 + %int-1_6460 = torch.constant.int -1 + %6060 = torch.aten.transpose.int %364, %int-2_6459, %int-1_6460 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6461 = torch.constant.int 5 + %6061 = torch.prims.convert_element_type %6060, %int5_6461 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_6462 = torch.constant.int 4 + %int4096_6463 = torch.constant.int 4096 + %6062 = torch.prim.ListConstruct %int4_6462, %int4096_6463 : (!torch.int, !torch.int) -> !torch.list + %6063 = torch.aten.view %6051, %6062 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6064 = torch.aten.mm %6063, %6061 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> %int4_6464 = torch.constant.int 4 - %5191 = torch.prim.ListConstruct %int4_6464, %358 : (!torch.int, !torch.int) -> !torch.list %int1_6465 = 
torch.constant.int 1 - %5192 = torch.prim.ListConstruct %358, %int1_6465 : (!torch.int, !torch.int) -> !torch.list - %int4_6466 = torch.constant.int 4 - %int0_6467 = torch.constant.int 0 - %cpu_6468 = torch.constant.device "cpu" - %false_6469 = torch.constant.bool false - %5193 = torch.aten.empty_strided %5191, %5192, %int4_6466, %int0_6467, %cpu_6468, %false_6469 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5193, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int23_6470 = torch.constant.int 23 - %5194 = torch.aten.fill.Scalar %5193, %int23_6470 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5194, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_6471 = torch.constant.int 32 - %5195 = torch.aten.mul.Scalar %arg3, %int32_6471 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5195, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_6472 = torch.constant.int 1 - %5196 = torch.aten.add.Tensor %5195, %5194, %int1_6472 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5196, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_6473 = torch.constant.int 4 - %5197 = torch.aten.mul.int %int4_6473, %358 : !torch.int, !torch.int -> !torch.int - %5198 = torch.prim.ListConstruct %5197 : (!torch.int) -> !torch.list - %5199 = torch.aten.view %5196, %5198 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %5199, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_6474 = torch.constant.int 32 - %int2_6475 = torch.constant.int 2 - %int32_6476 = torch.constant.int 32 - %int8_6477 = torch.constant.int 8 - %int128_6478 = torch.constant.int 128 - %5200 = torch.prim.ListConstruct %437, %int32_6474, %int2_6475, %int32_6476, %int8_6477, %int128_6478 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5201 = torch.aten.view %5190, %5200 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5201, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6479 = torch.constant.int 32 - %5202 = torch.aten.mul.int %437, %int32_6479 : !torch.int, !torch.int -> !torch.int - %int2_6480 = torch.constant.int 2 - %int32_6481 = torch.constant.int 32 - %int8_6482 = torch.constant.int 8 - %int128_6483 = torch.constant.int 128 - %5203 = torch.prim.ListConstruct %5202, %int2_6480, %int32_6481, %int8_6482, %int128_6483 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5204 = torch.aten.view %5201, %5203 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %5204, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_6484 = torch.constant.int 0 - %5205 = torch.aten.index_select %5204, %int0_6484, %5199 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %5205, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_6485 = torch.constant.int 4 - %int2_6486 = torch.constant.int 2 - 
%int32_6487 = torch.constant.int 32
- %int8_6488 = torch.constant.int 8
- %int128_6489 = torch.constant.int 128
- %5206 = torch.prim.ListConstruct %int4_6485, %358, %int2_6486, %int32_6487, %int8_6488, %int128_6489 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5207 = torch.aten.view %5205, %5206 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %5207, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int0_6490 = torch.constant.int 0
- %int0_6491 = torch.constant.int 0
- %int9223372036854775807_6492 = torch.constant.int 9223372036854775807
- %int1_6493 = torch.constant.int 1
- %5208 = torch.aten.slice.Tensor %5207, %int0_6490, %int0_6491, %int9223372036854775807_6492, %int1_6493 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %5208, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int1_6494 = torch.constant.int 1
- %int0_6495 = torch.constant.int 0
- %int9223372036854775807_6496 = torch.constant.int 9223372036854775807
- %int1_6497 = torch.constant.int 1
- %5209 = torch.aten.slice.Tensor %5208, %int1_6494, %int0_6495, %int9223372036854775807_6496, %int1_6497 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %5209, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int2_6498 = torch.constant.int 2
- %int0_6499 = torch.constant.int 0
- %5210 = torch.aten.select.int %5209, %int2_6498, %int0_6499 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %5210, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int32_6500 = torch.constant.int 32
- %5211 = torch.aten.mul.int %358, %int32_6500 : !torch.int, !torch.int -> !torch.int
- %int2_6501 = torch.constant.int 2
- %int0_6502 = torch.constant.int 0
- %int1_6503 = torch.constant.int 1
- %5212 = torch.aten.slice.Tensor %5210, %int2_6501, %int0_6502, %5211, %int1_6503 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %5212, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int0_6504 = torch.constant.int 0
- %5213 = torch.aten.clone %5212, %int0_6504 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %5213, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int1_6505 = torch.constant.int 1
- %5214 = torch.aten.size.int %5209, %int1_6505 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int
- %int32_6506 = torch.constant.int 32
- %5215 = torch.aten.mul.int %5214, %int32_6506 : !torch.int, !torch.int -> !torch.int
- %int4_6507 = torch.constant.int 4
- %int8_6508 = torch.constant.int 8
- %int128_6509 = torch.constant.int 128
- %5216 = torch.prim.ListConstruct %int4_6507, %5215, %int8_6508, %int128_6509 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5217 = torch.aten._unsafe_view %5213, %5216 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %5217, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int0_6510 = torch.constant.int 0
- %int0_6511 = torch.constant.int 0
- %int9223372036854775807_6512 = torch.constant.int 9223372036854775807
- %int1_6513 = torch.constant.int 1
- %5218 = torch.aten.slice.Tensor %5217, %int0_6510, %int0_6511, %int9223372036854775807_6512, %int1_6513 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %5218, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int0_6514 = torch.constant.int 0
- %int0_6515 = torch.constant.int 0
- %int9223372036854775807_6516 = torch.constant.int 9223372036854775807
- %int1_6517 = torch.constant.int 1
- %5219 = torch.aten.slice.Tensor %5207, %int0_6514, %int0_6515, %int9223372036854775807_6516, %int1_6517 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %5219, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
+ %int14336_6466 = torch.constant.int 14336
+ %6065 = torch.prim.ListConstruct %int4_6464, %int1_6465, %int14336_6466 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6066 = torch.aten.view %6064, %6065 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
+ %6067 = torch.aten.mul.Tensor %6059, %6066 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
+ %int-2_6467 = torch.constant.int -2
+ %int-1_6468 = torch.constant.int -1
+ %6068 = torch.aten.transpose.int %365, %int-2_6467, %int-1_6468 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
+ %int5_6469 = torch.constant.int 5
+ %6069 = torch.prims.convert_element_type %6068, %int5_6469 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16>
+ %int4_6470 = torch.constant.int 4
+ %int14336_6471 = torch.constant.int 14336
+ %6070 = torch.prim.ListConstruct %int4_6470, %int14336_6471 : (!torch.int, !torch.int) -> !torch.list<int>
+ %6071 = torch.aten.view %6067, %6070 : !torch.vtensor<[4,1,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,14336],f16>
+ %6072 = torch.aten.mm %6071, %6069 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16>
+ %int4_6472 = torch.constant.int 4
+ %int1_6473 = torch.constant.int 1
+ %int4096_6474 = torch.constant.int 4096
+ %6073 = torch.prim.ListConstruct %int4_6472, %int1_6473, %int4096_6474 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6074 = torch.aten.view %6072, %6073 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+ %int1_6475 = torch.constant.int 1
+ %6075 = torch.aten.add.Tensor %6041, %6074, %int1_6475 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+ %int6_6476 = torch.constant.int 6
+ %6076 = torch.prims.convert_element_type %6075, %int6_6476 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+ %int2_6477 = torch.constant.int 2
+ %6077 = torch.aten.pow.Tensor_Scalar %6076, %int2_6477 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+ %int-1_6478 = torch.constant.int -1
+ %6078 = torch.prim.ListConstruct %int-1_6478 : (!torch.int) -> !torch.list<int>
+ %true_6479 = torch.constant.bool true
+ %none_6480 = torch.constant.none
+ %6079 = torch.aten.mean.dim %6077, %6078, %true_6479, %none_6480 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
+ %float9.999990e-06_6481 = torch.constant.float 9.9999997473787516E-6
+ %int1_6482 = torch.constant.int 1
+ %6080 = torch.aten.add.Scalar %6079, %float9.999990e-06_6481, %int1_6482 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
+ %6081 = torch.aten.rsqrt %6080 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
+ %6082 = torch.aten.mul.Tensor %6076, %6081 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
+ %int5_6483 = torch.constant.int 5
+ %6083 = torch.prims.convert_element_type %6082, %int5_6483 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+ %6084 = torch.aten.mul.Tensor %366, %6083 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
+ %int5_6484 = torch.constant.int 5
+ %6085 = torch.prims.convert_element_type %6084, %int5_6484 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
+ %int-2_6485 = torch.constant.int -2
+ %int-1_6486 = torch.constant.int -1
+ %6086 = torch.aten.transpose.int %367, %int-2_6485, %int-1_6486 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int5_6487 = torch.constant.int 5
+ %6087 = torch.prims.convert_element_type %6086, %int5_6487 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16>
+ %int4_6488 = torch.constant.int 4
+ %int4096_6489 = torch.constant.int 4096
+ %6088 = torch.prim.ListConstruct %int4_6488, %int4096_6489 : (!torch.int, !torch.int) -> !torch.list<int>
+ %6089 = torch.aten.view %6085, %6088 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+ %6090 = torch.aten.mm %6089, %6087 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
+ %int4_6490 = torch.constant.int 4
+ %int1_6491 = torch.constant.int 1
+ %int4096_6492 = torch.constant.int 4096
+ %6091 = torch.prim.ListConstruct %int4_6490, %int1_6491, %int4096_6492 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6092 = torch.aten.view %6090, %6091 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
+ %int-2_6493 = torch.constant.int -2
+ %int-1_6494 = torch.constant.int -1
+ %6093 = torch.aten.transpose.int %368, %int-2_6493, %int-1_6494 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int5_6495 = torch.constant.int 5
+ %6094 = torch.prims.convert_element_type %6093, %int5_6495 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int4_6496 = torch.constant.int 4
+ %int4096_6497 = torch.constant.int 4096
+ %6095 = torch.prim.ListConstruct %int4_6496, %int4096_6497 : (!torch.int, !torch.int) -> !torch.list<int>
+ %6096 = torch.aten.view %6085, %6095 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+ %6097 = torch.aten.mm %6096, %6094 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
+ %int4_6498 = torch.constant.int 4
+ %int1_6499 = torch.constant.int 1
+ %int1024_6500 = torch.constant.int 1024
+ %6098 = torch.prim.ListConstruct %int4_6498, %int1_6499, %int1024_6500 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6099 = torch.aten.view %6097, %6098 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
+ %int-2_6501 = torch.constant.int -2
+ %int-1_6502 = torch.constant.int -1
+ %6100 = torch.aten.transpose.int %369, %int-2_6501, %int-1_6502 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int5_6503 = torch.constant.int 5
+ %6101 = torch.prims.convert_element_type %6100, %int5_6503 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16>
+ %int4_6504 = torch.constant.int 4
+ %int4096_6505 = torch.constant.int 4096
+ %6102 = torch.prim.ListConstruct %int4_6504, %int4096_6505 : (!torch.int, !torch.int) -> !torch.list<int>
+ %6103 = torch.aten.view %6085, %6102 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+ %6104 = torch.aten.mm %6103, %6101 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
+ %int4_6506 = torch.constant.int 4
+ %int1_6507 = torch.constant.int 1
+ %int1024_6508 = torch.constant.int 1024
+ %6105 = torch.prim.ListConstruct %int4_6506, %int1_6507, %int1024_6508 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6106 = torch.aten.view %6104, %6105 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
+ %int4_6509 = torch.constant.int 4
+ %int1_6510 = torch.constant.int 1
+ %int32_6511 = torch.constant.int 32
+ %int128_6512 = torch.constant.int 128
+ %6107 = torch.prim.ListConstruct %int4_6509, %int1_6510, %int32_6511, %int128_6512 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6108 = torch.aten.view %6092, %6107 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,32,128],f16>
+ %int4_6513 = torch.constant.int 4
+ %int1_6514 = torch.constant.int 1
+ %int8_6515 = torch.constant.int 8
+ %int128_6516 = torch.constant.int 128
+ %6109 = torch.prim.ListConstruct %int4_6513, %int1_6514, %int8_6515, %int128_6516 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6110 = torch.aten.view %6099, %6109 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
+ %int4_6517 = torch.constant.int 4
  %int1_6518 = torch.constant.int 1
- %int0_6519 = torch.constant.int 0
- %int9223372036854775807_6520 = torch.constant.int 9223372036854775807
+ %int8_6519 = torch.constant.int 8
+ %int128_6520 = torch.constant.int 128
+ %6111 = torch.prim.ListConstruct %int4_6517, %int1_6518, %int8_6519, %int128_6520 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6112 = torch.aten.view %6106, %6111 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
  %int1_6521 = torch.constant.int 1
- %5220 = torch.aten.slice.Tensor %5219, %int1_6518, %int0_6519, %int9223372036854775807_6520, %int1_6521 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %5220, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
  %int2_6522 = torch.constant.int 2
- %int1_6523 = torch.constant.int 1
- %5221 = torch.aten.select.int %5220, %int2_6522, %int1_6523 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %5221, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int2_6524 = torch.constant.int 2
- %int0_6525 = torch.constant.int 0
+ %6113 = torch.aten.transpose.int %6108, %int1_6521, %int2_6522 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+ %6114 = torch.aten.mul.Tensor %6113, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16>
+ %int3_6523 = torch.constant.int 3
+ %int0_6524 = torch.constant.int 0
+ %int64_6525 = torch.constant.int 64
  %int1_6526 = torch.constant.int 1
- %5222 = torch.aten.slice.Tensor %5221, %int2_6524, %int0_6525, %5211, %int1_6526 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %5222, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int0_6527 = torch.constant.int 0
- %5223 = torch.aten.clone %5222, %int0_6527 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %5223, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int1_6528 = torch.constant.int 1
- %5224 = torch.aten.size.int %5220, %int1_6528 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int
- %int32_6529 = torch.constant.int 32
- %5225 = torch.aten.mul.int %5224, %int32_6529 : !torch.int, !torch.int -> !torch.int
- %int4_6530 = torch.constant.int 4
- %int8_6531 = torch.constant.int 8
- %int128_6532 = torch.constant.int 128
- %5226 = torch.prim.ListConstruct %int4_6530, %5225, %int8_6531, %int128_6532 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5227 = torch.aten._unsafe_view %5223, %5226 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %5227, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int0_6533 = torch.constant.int 0
- %int0_6534 = torch.constant.int 0
- %int9223372036854775807_6535 = torch.constant.int 9223372036854775807
- %int1_6536 = torch.constant.int 1
- %5228 = torch.aten.slice.Tensor %5227, %int0_6533, %int0_6534, %int9223372036854775807_6535, %int1_6536 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %5228, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int-2_6537 = torch.constant.int -2
- %5229 = torch.aten.unsqueeze %5218, %int-2_6537 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
- torch.bind_symbolic_shape %5229, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
- %int1_6538 = torch.constant.int 1
- %5230 = torch.aten.size.int %5217, %int1_6538 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int
- %int4_6539 = torch.constant.int 4
- %int8_6540 = torch.constant.int 8
- %int4_6541 = torch.constant.int 4
- %int128_6542 = torch.constant.int 128
- %5231 = torch.prim.ListConstruct %int4_6539, %5230, %int8_6540, %int4_6541, %int128_6542 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %false_6543 = torch.constant.bool false
- %5232 = torch.aten.expand %5229, %5231, %false_6543 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %5232, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int0_6544 = torch.constant.int 0
- %5233 = torch.aten.clone %5232, %int0_6544 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %5233, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int4_6545 = torch.constant.int 4
- %int32_6546 = torch.constant.int 32
- %int128_6547 = torch.constant.int 128
- %5234 = torch.prim.ListConstruct %int4_6545, %5230, %int32_6546, %int128_6547 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5235 = torch.aten._unsafe_view %5233, %5234 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %5235, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
- %int-2_6548 = torch.constant.int -2
- %5236 = torch.aten.unsqueeze %5228, %int-2_6548 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
- torch.bind_symbolic_shape %5236, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
- %int1_6549 = torch.constant.int 1
- %5237 = torch.aten.size.int %5227, %int1_6549 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int
- %int4_6550 = torch.constant.int 4
- %int8_6551 = torch.constant.int 8
- %int4_6552 = torch.constant.int 4
- %int128_6553 = torch.constant.int 128
- %5238 = torch.prim.ListConstruct %int4_6550, %5237, %int8_6551, %int4_6552, %int128_6553 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %false_6554 = torch.constant.bool false
- %5239 = torch.aten.expand %5236, %5238, %false_6554 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %5239, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int0_6555 = torch.constant.int 0
- %5240 = torch.aten.clone %5239, %int0_6555 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
- torch.bind_symbolic_shape %5240, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
- %int4_6556 = torch.constant.int 4
- %int32_6557 = torch.constant.int 32
- %int128_6558 = torch.constant.int 128
- %5241 = torch.prim.ListConstruct %int4_6556, %5237, %int32_6557, %int128_6558 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5242 = torch.aten._unsafe_view %5240, %5241 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
- torch.bind_symbolic_shape %5242, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %6115 = torch.aten.slice.Tensor %6113, %int3_6523, %int0_6524, %int64_6525, %int1_6526 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16>
+ %int3_6527 = torch.constant.int 3
+ %int64_6528 = torch.constant.int 64
+ %int9223372036854775807_6529 = torch.constant.int 9223372036854775807
+ %int1_6530 = torch.constant.int 1
+ %6116 = torch.aten.slice.Tensor %6113, %int3_6527, %int64_6528, %int9223372036854775807_6529, %int1_6530 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16>
+ %6117 = torch.aten.neg %6116 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16>
+ %6118 = torch.prim.ListConstruct %6117, %6115 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list<vtensor>
+ %int-1_6531 = torch.constant.int -1
+ %6119 = torch.aten.cat %6118, %int-1_6531 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+ %6120 = torch.aten.mul.Tensor %6119, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16>
+ %int1_6532 = torch.constant.int 1
+ %6121 = torch.aten.add.Tensor %6114, %6120, %int1_6532 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
+ %int1_6533 = torch.constant.int 1
+ %int2_6534 = torch.constant.int 2
+ %6122 = torch.aten.transpose.int %6121, %int1_6533, %int2_6534 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
+ %int1_6535 = torch.constant.int 1
+ %int2_6536 = torch.constant.int 2
+ %6123 = torch.aten.transpose.int %6110, %int1_6535, %int2_6536 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16>
+ %6124 = torch.aten.mul.Tensor %6123, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16>
+ %int3_6537 = torch.constant.int 3
+ %int0_6538 = torch.constant.int 0
+ %int64_6539 = torch.constant.int 64
+ %int1_6540 = torch.constant.int 1
+ %6125 = torch.aten.slice.Tensor %6123, %int3_6537, %int0_6538, %int64_6539, %int1_6540 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16>
+ %int3_6541 = torch.constant.int 3
+ %int64_6542 = torch.constant.int 64
+ %int9223372036854775807_6543 = torch.constant.int 9223372036854775807
+ %int1_6544 = torch.constant.int 1
+ %6126 = torch.aten.slice.Tensor %6123, %int3_6541, %int64_6542, %int9223372036854775807_6543, %int1_6544 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16>
+ %6127 = torch.aten.neg %6126 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16>
+ %6128 = torch.prim.ListConstruct %6127, %6125 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list<vtensor>
+ %int-1_6545 = torch.constant.int -1
+ %6129 = torch.aten.cat %6128, %int-1_6545 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4,8,1,128],f16>
+ %6130 = torch.aten.mul.Tensor %6129, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16>
+ %int1_6546 = torch.constant.int 1
+ %6131 = torch.aten.add.Tensor %6124, %6130, %int1_6546 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16>
+ %int1_6547 = torch.constant.int 1
+ %int2_6548 = torch.constant.int 2
+ %6132 = torch.aten.transpose.int %6131, %int1_6547, %int2_6548 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+ %int32_6549 = torch.constant.int 32
+ %6133 = torch.aten.floor_divide.Scalar %arg2, %int32_6549 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
+ %int1_6550 = torch.constant.int 1
+ %6134 = torch.aten.unsqueeze %6133, %int1_6550 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
+ %int1_6551 = torch.constant.int 1
+ %false_6552 = torch.constant.bool false
+ %6135 = torch.aten.gather %arg3, %int1_6551, %6134, %false_6552 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64>
+ %int4_6553 = torch.constant.int 4
+ %int1_6554 = torch.constant.int 1
+ %int1_6555 = torch.constant.int 1
+ %6136 = torch.prim.ListConstruct %int4_6553, %int1_6554, %int1_6555 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6137 = torch.aten.view %6135, %6136 : !torch.vtensor<[4,1],si64>, !torch.list<int> -> !torch.vtensor<[4,1,1],si64>
+ %int32_6556 = torch.constant.int 32
+ %6138 = torch.aten.remainder.Scalar %arg2, %int32_6556 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
+ %int4_6557 = torch.constant.int 4
+ %int1_6558 = torch.constant.int 1
  %int1_6559 = torch.constant.int 1
- %int2_6560 = torch.constant.int 2
- %5243 = torch.aten.transpose.int %5123, %int1_6559, %int2_6560 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
- %int1_6561 = torch.constant.int 1
- %int2_6562 = torch.constant.int 2
- %5244 = torch.aten.transpose.int %5235, %int1_6561, %int2_6562 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %5244, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %int1_6563 = torch.constant.int 1
- %int2_6564 = torch.constant.int 2
- %5245 = torch.aten.transpose.int %5242, %int1_6563, %int2_6564 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
- torch.bind_symbolic_shape %5245, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
- %float0.000000e00_6565 = torch.constant.float 0.000000e+00
- %false_6566 = torch.constant.bool false
- %none_6567 = torch.constant.none
- %5246:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5243, %5244, %5245, %float0.000000e00_6565, %false_6566, %368, %none_6567) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
- %int1_6568 = torch.constant.int 1
- %int2_6569 = torch.constant.int 2
- %5247 = torch.aten.transpose.int %5246#0, %int1_6568, %int2_6569 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
- %int4_6570 = torch.constant.int 4
+ %6139 = torch.prim.ListConstruct %int4_6557, %int1_6558, %int1_6559 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6140 = torch.aten.view %6138, %6139 : !torch.vtensor<[4],si64>, !torch.list<int> -> !torch.vtensor<[4,1,1],si64>
+ %int8_6560 = torch.constant.int 8
+ %none_6561 = torch.constant.none
+ %none_6562 = torch.constant.none
+ %cpu_6563 = torch.constant.device "cpu"
+ %false_6564 = torch.constant.bool false
+ %6141 = torch.aten.arange %int8_6560, %none_6561, %none_6562, %cpu_6563, %false_6564 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64>
+ %int1_6565 = torch.constant.int 1
+ %int1_6566 = torch.constant.int 1
+ %int8_6567 = torch.constant.int 8
+ %6142 = torch.prim.ListConstruct %int1_6565, %int1_6566, %int8_6567 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6143 = torch.aten.view %6141, %6142 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[1,1,8],si64>
+ %none_6568 = torch.constant.none
+ %6144 = torch.aten.clone %370, %none_6568 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %6145 = torch.aten.detach %6144 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %6146 = torch.aten.detach %6145 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %6147 = torch.aten.detach %6146 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %int1_6569 = torch.constant.int 1
+ %int1_6570 = torch.constant.int 1
  %int1_6571 = torch.constant.int 1
- %int4096_6572 = torch.constant.int 4096
- %5248 = torch.prim.ListConstruct %int4_6570, %int1_6571, %int4096_6572 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5249 = torch.aten.view %5247, %5248 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
- %int-2_6573 = torch.constant.int -2
- %int-1_6574 = torch.constant.int -1
- %5250 = torch.aten.transpose.int %260, %int-2_6573, %int-1_6574 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_6575 = torch.constant.int 4
- %int4096_6576 = torch.constant.int 4096
- %5251 = torch.prim.ListConstruct %int4_6575, %int4096_6576 : (!torch.int, !torch.int) -> !torch.list<int>
- %5252 = torch.aten.view %5249, %5251 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %5253 = torch.aten.mm %5252, %5250 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
- %int4_6577 = torch.constant.int 4
- %int1_6578 = torch.constant.int 1
- %int4096_6579 = torch.constant.int 4096
- %5254 = torch.prim.ListConstruct %int4_6577, %int1_6578, %int4096_6579 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5255 = torch.aten.view %5253, %5254 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
- %int1_6580 = torch.constant.int 1
- %5256 = torch.aten.add.Tensor %5083, %5255, %int1_6580 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %int6_6581 = torch.constant.int 6
- %5257 = torch.prims.convert_element_type %5256, %int6_6581 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
+ %6148 = torch.prim.ListConstruct %int1_6569, %int1_6570, %int1_6571 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6149 = torch.aten.view %6147, %6148 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64>
+ %int32_6572 = torch.constant.int 32
+ %6150 = torch.aten.mul.Scalar %6137, %int32_6572 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int26 = torch.constant.int 26
+ %int1_6573 = torch.constant.int 1
+ %6151 = torch.aten.add.Scalar %6150, %int26, %int1_6573 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int2_6574 = torch.constant.int 2
+ %6152 = torch.aten.mul.Scalar %6151, %int2_6574 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int1_6575 = torch.constant.int 1
+ %6153 = torch.aten.add.Tensor %6152, %6149, %int1_6575 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int8_6576 = torch.constant.int 8
+ %6154 = torch.aten.mul.Scalar %6153, %int8_6576 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int1_6577 = torch.constant.int 1
+ %6155 = torch.aten.add.Tensor %6154, %6143, %int1_6577 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int32_6578 = torch.constant.int 32
+ %6156 = torch.aten.mul.Scalar %6155, %int32_6578 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int1_6579 = torch.constant.int 1
+ %6157 = torch.aten.add.Tensor %6156, %6140, %int1_6579 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int5_6580 = torch.constant.int 5
+ %6158 = torch.prims.convert_element_type %6132, %int5_6580 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+ %int32_6581 = torch.constant.int 32
  %int2_6582 = torch.constant.int 2
- %5258 = torch.aten.pow.Tensor_Scalar %5257, %int2_6582 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
- %int-1_6583 = torch.constant.int -1
- %5259 = torch.prim.ListConstruct %int-1_6583 : (!torch.int) -> !torch.list<int>
- %true_6584 = torch.constant.bool true
- %none_6585 = torch.constant.none
- %5260 = torch.aten.mean.dim %5258, %5259, %true_6584, %none_6585 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
- %float9.999990e-06_6586 = torch.constant.float 9.9999997473787516E-6
- %int1_6587 = torch.constant.int 1
- %5261 = torch.aten.add.Scalar %5260, %float9.999990e-06_6586, %int1_6587 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
- %5262 = torch.aten.rsqrt %5261 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
- %5263 = torch.aten.mul.Tensor %5257, %5262 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
- %int5_6588 = torch.constant.int 5
- %5264 = torch.prims.convert_element_type %5263, %int5_6588 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %5265 = torch.aten.mul.Tensor %261, %5264 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
- %int5_6589 = torch.constant.int 5
- %5266 = torch.prims.convert_element_type %5265, %int5_6589 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %int-2_6590 = torch.constant.int -2
- %int-1_6591 = torch.constant.int -1
- %5267 = torch.aten.transpose.int %262, %int-2_6590, %int-1_6591 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_6592 = torch.constant.int 4
- %int4096_6593 = torch.constant.int 4096
- %5268 = torch.prim.ListConstruct %int4_6592, %int4096_6593 : (!torch.int, !torch.int) -> !torch.list<int>
- %5269 = torch.aten.view %5266, %5268 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %5270 = torch.aten.mm %5269, %5267 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
- %int4_6594 = torch.constant.int 4
- %int1_6595 = torch.constant.int 1
- %int14336_6596 = torch.constant.int 14336
- %5271 = torch.prim.ListConstruct %int4_6594, %int1_6595, %int14336_6596 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5272 = torch.aten.view %5270, %5271 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
- %5273 = torch.aten.silu %5272 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
- %int-2_6597 = torch.constant.int -2
- %int-1_6598 = torch.constant.int -1
- %5274 = torch.aten.transpose.int %263, %int-2_6597, %int-1_6598 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
- %int4_6599 = torch.constant.int 4
- %int4096_6600 = torch.constant.int 4096
- %5275 = torch.prim.ListConstruct %int4_6599, %int4096_6600 : (!torch.int, !torch.int) -> !torch.list<int>
- %5276 = torch.aten.view %5266, %5275 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %5277 = torch.aten.mm %5276, %5274 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
- %int4_6601 = torch.constant.int 4
+ %int8_6583 = torch.constant.int 8
+ %int32_6584 = torch.constant.int 32
+ %int128_6585 = torch.constant.int 128
+ %6159 = torch.prim.ListConstruct %456, %int32_6581, %int2_6582, %int8_6583, %int32_6584, %int128_6585 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6160 = torch.aten.view %5980, %6159 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %6160, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int128_6586 = torch.constant.int 128
+ %6161 = torch.prim.ListConstruct %596, %int128_6586 : (!torch.int, !torch.int) -> !torch.list<int>
+ %6162 = torch.aten.view %6160, %6161 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %6162, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %6163 = torch.prim.ListConstruct %6157 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+ %false_6587 = torch.constant.bool false
+ %6164 = torch.aten.index_put %6162, %6163, %6158, %false_6587 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %6164, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %int32_6588 = torch.constant.int 32
+ %int2_6589 = torch.constant.int 2
+ %int8_6590 = torch.constant.int 8
+ %int32_6591 = torch.constant.int 32
+ %int128_6592 = torch.constant.int 128
+ %6165 = torch.prim.ListConstruct %456, %int32_6588, %int2_6589, %int8_6590, %int32_6591, %int128_6592 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6166 = torch.aten.view %6164, %6165 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %6166, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_6593 = torch.constant.int 2097152
+ %6167 = torch.prim.ListConstruct %456, %int2097152_6593 : (!torch.int, !torch.int) -> !torch.list<int>
+ %6168 = torch.aten.view %6166, %6167 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %6168, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %int32_6594 = torch.constant.int 32
+ %int2_6595 = torch.constant.int 2
+ %int8_6596 = torch.constant.int 8
+ %int32_6597 = torch.constant.int 32
+ %int128_6598 = torch.constant.int 128
+ %6169 = torch.prim.ListConstruct %456, %int32_6594, %int2_6595, %int8_6596, %int32_6597, %int128_6598 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6170 = torch.aten.view %6168, %6169 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %6170, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int128_6599 = torch.constant.int 128
+ %6171 = torch.prim.ListConstruct %596, %int128_6599 : (!torch.int, !torch.int) -> !torch.list<int>
+ %6172 = torch.aten.view %6170, %6171 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %6172, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %none_6600 = torch.constant.none
+ %6173 = torch.aten.clone %371, %none_6600 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %6174 = torch.aten.detach %6173 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %6175 = torch.aten.detach %6174 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %6176 = torch.aten.detach %6175 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %int1_6601 = torch.constant.int 1
  %int1_6602 = torch.constant.int 1
- %int14336_6603 = torch.constant.int 14336
- %5278 = torch.prim.ListConstruct %int4_6601, %int1_6602, %int14336_6603 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5279 = torch.aten.view %5277, %5278 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
- %5280 = torch.aten.mul.Tensor %5273, %5279 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
- %int-2_6604 = torch.constant.int -2
- %int-1_6605 = torch.constant.int -1
- %5281 = torch.aten.transpose.int %264, %int-2_6604, %int-1_6605 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
- %int4_6606 = torch.constant.int 4
- %int14336_6607 = torch.constant.int 14336
- %5282 = torch.prim.ListConstruct %int4_6606, %int14336_6607 : (!torch.int, !torch.int) -> !torch.list<int>
- %5283 = torch.aten.view %5280, %5282 : !torch.vtensor<[4,1,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,14336],f16>
- %5284 = torch.aten.mm %5283, %5281 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16>
- %int4_6608 = torch.constant.int 4
- %int1_6609 = torch.constant.int 1
- %int4096_6610 = torch.constant.int 4096
- %5285 = torch.prim.ListConstruct %int4_6608, %int1_6609, %int4096_6610 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5286 = torch.aten.view %5284, %5285 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
- %int1_6611 = torch.constant.int 1
- %5287 = torch.aten.add.Tensor %5256, %5286, %int1_6611 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %int6_6612 = torch.constant.int 6
- %5288 = torch.prims.convert_element_type %5287, %int6_6612 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
- %int2_6613 = torch.constant.int 2
- %5289 = torch.aten.pow.Tensor_Scalar %5288, %int2_6613 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
- %int-1_6614 = torch.constant.int -1
- %5290 = torch.prim.ListConstruct %int-1_6614 : (!torch.int) -> !torch.list<int>
- %true_6615 = torch.constant.bool true
- %none_6616 = torch.constant.none
- %5291 = torch.aten.mean.dim %5289, %5290, %true_6615, %none_6616 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
- %float9.999990e-06_6617 = torch.constant.float 9.9999997473787516E-6
- %int1_6618 = torch.constant.int 1
- %5292 = torch.aten.add.Scalar %5291, %float9.999990e-06_6617, %int1_6618 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
- %5293 = torch.aten.rsqrt %5292 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
- %5294 = torch.aten.mul.Tensor %5288, %5293 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
- %int5_6619 = torch.constant.int 5
- %5295 = torch.prims.convert_element_type %5294, %int5_6619 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %5296 = torch.aten.mul.Tensor %265, %5295 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
- %int5_6620 = torch.constant.int 5
- %5297 = torch.prims.convert_element_type %5296, %int5_6620 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
- %int-2_6621 = torch.constant.int -2
- %int-1_6622 = torch.constant.int -1
- %5298 = torch.aten.transpose.int %266, %int-2_6621, %int-1_6622 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
- %int4_6623 = torch.constant.int 4
- %int4096_6624 = torch.constant.int 4096
- %5299 = torch.prim.ListConstruct %int4_6623, %int4096_6624 : (!torch.int, !torch.int) -> !torch.list<int>
- %5300 = torch.aten.view %5297, %5299 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %5301 = torch.aten.mm %5300, %5298 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
- %int4_6625 = torch.constant.int 4
- %int1_6626 = torch.constant.int 1
- %int4096_6627 = torch.constant.int 4096
- %5302 = torch.prim.ListConstruct %int4_6625, %int1_6626, %int4096_6627 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5303 = torch.aten.view %5301, %5302 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
- %int-2_6628 = torch.constant.int -2
- %int-1_6629 = torch.constant.int -1
- %5304 = torch.aten.transpose.int %267, %int-2_6628, %int-1_6629 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_6630 = torch.constant.int 4
- %int4096_6631 = torch.constant.int 4096
- %5305 = torch.prim.ListConstruct %int4_6630, %int4096_6631 : (!torch.int, !torch.int) -> !torch.list<int>
- %5306 = torch.aten.view %5297, %5305 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %5307 = torch.aten.mm %5306, %5304 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
- %int4_6632 = torch.constant.int 4
- %int1_6633 = torch.constant.int 1
- %int1024_6634 = torch.constant.int 1024
- %5308 = torch.prim.ListConstruct %int4_6632, %int1_6633, %int1024_6634 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5309 = torch.aten.view %5307, %5308 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
- %int-2_6635 = torch.constant.int -2
- %int-1_6636 = torch.constant.int -1
- %5310 = torch.aten.transpose.int %268, %int-2_6635, %int-1_6636 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16>
- %int4_6637 = torch.constant.int 4
- %int4096_6638 = torch.constant.int 4096
- %5311 = torch.prim.ListConstruct %int4_6637, %int4096_6638 : (!torch.int, !torch.int) -> !torch.list<int>
- %5312 = torch.aten.view %5297, %5311 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
- %5313 = torch.aten.mm %5312, %5310 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16>
- %int4_6639 = torch.constant.int 4
- %int1_6640 = torch.constant.int 1
- %int1024_6641 = torch.constant.int 1024
- %5314 = torch.prim.ListConstruct %int4_6639, %int1_6640, %int1024_6641 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5315 = torch.aten.view %5313, %5314 : !torch.vtensor<[4,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,1024],f16>
+ %int1_6603 = torch.constant.int 1
+ %6177 = torch.prim.ListConstruct %int1_6601, %int1_6602, %int1_6603 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6178 = torch.aten.view %6176, %6177 : !torch.vtensor<[],si64>, !torch.list<int> -> !torch.vtensor<[1,1,1],si64>
+ %int32_6604 = torch.constant.int 32
+ %6179 = torch.aten.mul.Scalar %6137, %int32_6604 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int26_6605 = torch.constant.int 26
+ %int1_6606 = torch.constant.int 1
+ %6180 = torch.aten.add.Scalar %6179, %int26_6605, %int1_6606 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int2_6607 = torch.constant.int 2
+ %6181 = torch.aten.mul.Scalar %6180, %int2_6607 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int1_6608 = torch.constant.int 1
+ %6182 = torch.aten.add.Tensor %6181, %6178, %int1_6608 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int8_6609 = torch.constant.int 8
+ %6183 = torch.aten.mul.Scalar %6182, %int8_6609 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64>
+ %int1_6610 = torch.constant.int 1
+ %6184 = torch.aten.add.Tensor %6183, %6143, %int1_6610 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int32_6611 = torch.constant.int 32
+ %6185 = torch.aten.mul.Scalar %6184, %int32_6611 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int1_6612 = torch.constant.int 1
+ %6186 = torch.aten.add.Tensor %6185, %6140, %int1_6612 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64>
+ %int5_6613 = torch.constant.int 5
+ %6187 = torch.prims.convert_element_type %6112, %int5_6613 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
+ %6188 = torch.prim.ListConstruct %6186 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>>
+ %false_6614 = torch.constant.bool false
+ %6189 = torch.aten.index_put %6172, %6188, %6187, %false_6614 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16>
+ torch.bind_symbolic_shape %6189, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16>
+ %int32_6615 = torch.constant.int 32
+ %int2_6616 = torch.constant.int 2
+ %int8_6617 = torch.constant.int 8
+ %int32_6618 = torch.constant.int 32
+ %int128_6619 = torch.constant.int 128
+ %6190 = torch.prim.ListConstruct %456, %int32_6615, %int2_6616, %int8_6617, %int32_6618, %int128_6619 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6191 = torch.aten.view %6189, %6190 : !torch.vtensor<[?,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %6191, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %int2097152_6620 = torch.constant.int 2097152
+ %6192 = torch.prim.ListConstruct %456, %int2097152_6620 : (!torch.int, !torch.int) -> !torch.list<int>
+ %6193 = torch.aten.view %6191, %6192 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
+ torch.bind_symbolic_shape %6193, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
+ %none_6621 = torch.constant.none
+ %6194 = torch.aten.clone %372, %none_6621 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %6195 = torch.aten.detach %6194 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %6196 = torch.aten.detach %6195 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %6197 = torch.aten.detach %6196 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %none_6622 = torch.constant.none
+ %6198 = torch.aten.clone %373, %none_6622 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %6199 = torch.aten.detach %6198 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %6200 = torch.aten.detach %6199 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %6201 = torch.aten.detach %6200 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %none_6623 = torch.constant.none
+ %6202 = torch.aten.clone %374, %none_6623 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
+ %6203 = torch.aten.detach %6202 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %6204 = torch.aten.detach %6203 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %6205 = torch.aten.detach %6204 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+ %int32_6624 = torch.constant.int 32
+ %int2_6625 = torch.constant.int 2
+ %int8_6626 = torch.constant.int 8
+ %int32_6627 = torch.constant.int 32
+ %int128_6628 = torch.constant.int 128
+ %6206 = torch.prim.ListConstruct %456, %int32_6624, %int2_6625, %int8_6626, %int32_6627, %int128_6628 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6207 = torch.aten.view %6193, %6206 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16>
+ torch.bind_symbolic_shape %6207, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16>
+ %6208 = torch_c.to_builtin_tensor %6207 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+ %6209 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+ %cast_6629 = tensor.cast %6209 : tensor<4x?xi64> to tensor<?x?xi64>
+ %6210 = torch_c.to_builtin_tensor %6197 : !torch.vtensor<[],si64> -> tensor<i64>
+ %6211 = torch_c.to_builtin_tensor %6201 : !torch.vtensor<[],si64> -> tensor<i64>
+ %6212 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%6208, %cast_6629, %6210, %6211) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+ %cast_6630 = tensor.cast %6212 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+ %6213 = torch_c.from_builtin_tensor %cast_6630 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+ torch.bind_symbolic_shape %6213, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+ %6214 = torch_c.to_builtin_tensor %6207 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16>
+ %6215 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64>
+ %cast_6631 = tensor.cast %6215 : tensor<4x?xi64> to tensor<?x?xi64>
+ %6216 = torch_c.to_builtin_tensor %6197 : !torch.vtensor<[],si64> -> tensor<i64>
+ %6217 = torch_c.to_builtin_tensor %6205 : !torch.vtensor<[],si64> -> tensor<i64>
+ %6218 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%6214, %cast_6631, %6216, %6217) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16>
+ %cast_6632 = tensor.cast %6218 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16>
+ %6219 = torch_c.from_builtin_tensor %cast_6632 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16>
+ torch.bind_symbolic_shape %6219, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16>
+ %int2_6633 = torch.constant.int 2
+ %int3_6634 = torch.constant.int 3
+ %6220 = torch.aten.transpose.int %6213, %int2_6633, %int3_6634 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %6220, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int0_6635 = torch.constant.int 0
+ %6221 = torch.aten.clone %6220, %int0_6635 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %6221, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int4_6636 = torch.constant.int 4
+ %int8_6637 = torch.constant.int 8
+ %int128_6638 = torch.constant.int 128
+ %6222 = torch.prim.ListConstruct %int4_6636, %457, %int8_6637, %int128_6638 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6223 = torch.aten._unsafe_view %6221, %6222 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %6223, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int2_6639 = torch.constant.int 2
+ %int3_6640 = torch.constant.int 3
+ %6224 = torch.aten.transpose.int %6219, %int2_6639, %int3_6640 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %6224, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
+ %int0_6641 = torch.constant.int 0
+ %6225 = torch.aten.clone %6224, %int0_6641 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
+ torch.bind_symbolic_shape %6225, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
  %int4_6642 = torch.constant.int 4
- %int1_6643 = torch.constant.int 1
- %int32_6644 = torch.constant.int 32
- %int128_6645 = torch.constant.int 128
- %5316 = torch.prim.ListConstruct %int4_6642, %int1_6643, %int32_6644, %int128_6645 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5317 = torch.aten.view %5303, %5316 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,32,128],f16>
+ %int8_6643 = torch.constant.int 8
+ %int128_6644 = torch.constant.int 128
+ %6226 = torch.prim.ListConstruct %int4_6642, %457, %int8_6643, %int128_6644 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6227 = torch.aten._unsafe_view %6225, %6226 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
+ torch.bind_symbolic_shape %6227, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
+ %int-2_6645 = torch.constant.int -2
+ %6228 = torch.aten.unsqueeze %6223, %int-2_6645 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+ torch.bind_symbolic_shape %6228, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
  %int4_6646 = torch.constant.int 4
- %int1_6647 = torch.constant.int 1
- %int8_6648 = torch.constant.int 8
+ %int8_6647 = torch.constant.int 8
+ %int4_6648 = torch.constant.int 4
  %int128_6649 = torch.constant.int 128
- %5318 = torch.prim.ListConstruct %int4_6646, %int1_6647, %int8_6648, %int128_6649 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5319 = torch.aten.view %5309, %5318 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
- %int4_6650 = torch.constant.int 4
- %int1_6651 = torch.constant.int 1
- %int8_6652 = torch.constant.int 8
- %int128_6653 = torch.constant.int 128
- %5320 = torch.prim.ListConstruct %int4_6650, %int1_6651, %int8_6652, %int128_6653 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5321 = torch.aten.view %5315, %5320 : !torch.vtensor<[4,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[4,1,8,128],f16>
- %int6_6654 = torch.constant.int 6
- %5322 = torch.prims.convert_element_type %5317, %int6_6654 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32>
- %5323 = torch_c.to_builtin_tensor %5322 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32>
- %5324 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32>
- %5325 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%5323, %5324) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32>
- %5326 = torch_c.from_builtin_tensor %5325 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32>
- %int5_6655 = torch.constant.int 5
- %5327 = torch.prims.convert_element_type %5326, %int5_6655 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
- %int6_6656 = torch.constant.int 6
- %5328 = torch.prims.convert_element_type %5319, %int6_6656 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32>
- %5329 = torch_c.to_builtin_tensor %5328 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32>
- %5330 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32>
- %5331 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%5329, %5330) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32>
- %5332 = torch_c.from_builtin_tensor %5331 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32>
- %int5_6657 = torch.constant.int 5
- %5333 = torch.prims.convert_element_type %5332, %int5_6657 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16>
- %int32_6658 = torch.constant.int 32
- %5334 = torch.aten.floor_divide.Scalar %arg2, %int32_6658 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
- %int1_6659 = torch.constant.int 1
- %5335 = torch.aten.unsqueeze %5334, %int1_6659 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int1_6660 = torch.constant.int 1
- %false_6661 = torch.constant.bool false
- %5336 = torch.aten.gather %arg3, %int1_6660, %5335, %false_6661 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64>
- %int32_6662 = torch.constant.int 32
- %5337 = torch.aten.remainder.Scalar %arg2, %int32_6662 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
- %int1_6663 = torch.constant.int 1
- %5338 = torch.aten.unsqueeze %5337, %int1_6663 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %none_6664 = torch.constant.none
- %5339 = torch.aten.clone %269, %none_6664 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
- %int0_6665 = torch.constant.int 0
- %5340 = torch.aten.unsqueeze %5339, %int0_6665 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64>
- %int4_6666 = torch.constant.int 4
+ %6229 = torch.prim.ListConstruct %int4_6646, %457, %int8_6647, %int4_6648, %int128_6649 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %false_6650 = torch.constant.bool false
+ %6230 = torch.aten.expand %6228, %6229, %false_6650 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %6230, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int0_6651 = torch.constant.int 0
+ %6231 = torch.aten.clone %6230, %int0_6651 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %6231, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4_6652 = torch.constant.int 4
+ %int32_6653 = torch.constant.int 32
+ %int128_6654 = torch.constant.int 128
+ %6232 = torch.prim.ListConstruct %int4_6652, %457, %int32_6653, %int128_6654 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6233 = torch.aten._unsafe_view %6231, %6232 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %6233, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int-2_6655 = torch.constant.int -2
+ %6234 = torch.aten.unsqueeze %6227, %int-2_6655 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16>
+ torch.bind_symbolic_shape %6234, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16>
+ %int4_6656 = torch.constant.int 4
+ %int8_6657 = torch.constant.int 8
+ %int4_6658 = torch.constant.int 4
+ %int128_6659 = torch.constant.int 128
+ %6235 = torch.prim.ListConstruct %int4_6656, %457, %int8_6657, %int4_6658, %int128_6659 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %false_6660 = torch.constant.bool false
+ %6236 = torch.aten.expand %6234, %6235, %false_6660 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %6236, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int0_6661 = torch.constant.int 0
+ %6237 = torch.aten.clone %6236, %int0_6661 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
+ torch.bind_symbolic_shape %6237, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
+ %int4_6662 = torch.constant.int 4
+ %int32_6663 = torch.constant.int 32
+ %int128_6664 = torch.constant.int 128
+ %6238 = torch.prim.ListConstruct %int4_6662, %457, %int32_6663, %int128_6664 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %6239 = torch.aten._unsafe_view %6237, %6238 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
+ torch.bind_symbolic_shape %6239, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
+ %int1_6665 = torch.constant.int 1
+ %int2_6666 = torch.constant.int 2
+ %6240 = torch.aten.transpose.int %6122, %int1_6665, %int2_6666 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
  %int1_6667 = torch.constant.int 1
- %5341 = torch.prim.ListConstruct %int4_6666, %int1_6667 : (!torch.int, !torch.int) -> !torch.list<int>
- %int1_6668 = torch.constant.int 1
+ %int2_6668 = torch.constant.int 2
+ %6241 = torch.aten.transpose.int %6233, %int1_6667, %int2_6668 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %6241, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
  %int1_6669 = torch.constant.int 1
- %5342 = torch.prim.ListConstruct %int1_6668, %int1_6669 : (!torch.int, !torch.int) -> !torch.list<int>
- %int4_6670 = torch.constant.int 4
- %int0_6671 = torch.constant.int 0
- %cpu_6672 = torch.constant.device "cpu"
- %false_6673 = torch.constant.bool false
- %5343 = torch.aten.empty_strided %5341, %5342, %int4_6670, %int0_6671, %cpu_6672, %false_6673 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64>
- %int24 = torch.constant.int 24
- %5344 = torch.aten.fill.Scalar %5343, %int24 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int4_6674 = torch.constant.int 4
- %int1_6675 = torch.constant.int 1
- %5345 = torch.prim.ListConstruct %int4_6674, %int1_6675 : (!torch.int, !torch.int) -> !torch.list<int>
- %5346 = torch.aten.repeat %5340, %5345 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64>
- %int32_6676 = torch.constant.int 32
- %5347 = torch.aten.mul.Scalar %5336, %int32_6676 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
+ %int2_6670 = torch.constant.int 2
+ %6242 = torch.aten.transpose.int %6239, %int1_6669, %int2_6670 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
+ torch.bind_symbolic_shape %6242, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
+ %float0.000000e00_6671 = torch.constant.float 0.000000e+00
+ %false_6672 = torch.constant.bool false
+ %none_6673 = torch.constant.none
+ %6243:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6240, %6241, %6242, %float0.000000e00_6671, %false_6672, %470, %none_6673) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
+ %int1_6674 = torch.constant.int 1
+ %int2_6675 = torch.constant.int 2
+ %6244 = torch.aten.transpose.int %6243#0, %int1_6674, %int2_6675 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
+ %int4_6676 = torch.constant.int 4
  %int1_6677 = torch.constant.int 1
- %5348 = torch.aten.add.Tensor %5347, %5344, %int1_6677 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int2_6678 = torch.constant.int 2
- %5349 = torch.aten.mul.Scalar %5348, %int2_6678 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int1_6679 = torch.constant.int 1
- %5350 = torch.aten.add.Tensor %5349, %5346, %int1_6679 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int32_6680 = torch.constant.int 32
- %5351 = torch.aten.mul.Scalar %5350, %int32_6680 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int1_6681 = torch.constant.int 1
- %5352 = torch.aten.add.Tensor %5351, %5338, %int1_6681 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int32_6682 = torch.constant.int 32
- %int2_6683 = torch.constant.int 2
- %int32_6684 = torch.constant.int 32
- %int8_6685 = torch.constant.int 8
- %int128_6686 = torch.constant.int 128
- %5353 = torch.prim.ListConstruct %437, %int32_6682, %int2_6683, %int32_6684, %int8_6685, %int128_6686 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5354 = torch.aten.view %5190, %5353 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %5354, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int32_6687 = torch.constant.int 32
- %5355 = torch.aten.mul.int %437, %int32_6687 : !torch.int, !torch.int -> !torch.int
- %int2_6688 = torch.constant.int 2
- %5356 = torch.aten.mul.int %5355, %int2_6688 : !torch.int, !torch.int -> !torch.int
- %int32_6689 = torch.constant.int 32
- %5357 = torch.aten.mul.int %5356, %int32_6689 : !torch.int, !torch.int -> !torch.int
- %int8_6690 = torch.constant.int 8
- %int128_6691 = torch.constant.int 128
- %5358 = torch.prim.ListConstruct %5357, %int8_6690, %int128_6691 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5359 = torch.aten.view %5354, %5358 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,128],f16>
- torch.bind_symbolic_shape %5359, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
- %5360 = torch.prim.ListConstruct %5352 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>>
- %false_6692 = torch.constant.bool false
- %5361 = torch.aten.index_put %5359, %5360, %5333, %false_6692 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16>
- torch.bind_symbolic_shape %5361, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
- %int32_6693 = torch.constant.int 32
- %int2_6694 = torch.constant.int 2
- %int32_6695 = torch.constant.int 32
- %int8_6696 = torch.constant.int 8
- %int128_6697 = torch.constant.int 128
- %5362 = torch.prim.ListConstruct %437, %int32_6693, %int2_6694, %int32_6695, %int8_6696, %int128_6697 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5363 = torch.aten.view %5361, %5362 : !torch.vtensor<[?,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %5363, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int2097152_6698 = torch.constant.int 2097152
- %5364 = torch.prim.ListConstruct %437, %int2097152_6698 : (!torch.int, !torch.int) -> !torch.list<int>
- %5365 = torch.aten.view %5363, %5364 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2097152],f16>
- torch.bind_symbolic_shape %5365, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16>
- %int32_6699 = torch.constant.int 32
- %int2_6700 = torch.constant.int 2
- %int32_6701 = torch.constant.int 32
- %int8_6702 = torch.constant.int 8
- %int128_6703 = torch.constant.int 128
- %5366 = torch.prim.ListConstruct %437, %int32_6699, %int2_6700, %int32_6701, %int8_6702, %int128_6703 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5367 = torch.aten.view %5365, %5366 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %5367, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int8_6704 = torch.constant.int 8
- %int128_6705 = torch.constant.int 128
- %5368 = torch.prim.ListConstruct %5357, %int8_6704, %int128_6705 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %5369 = torch.aten.view %5367, %5368 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,128],f16>
- torch.bind_symbolic_shape %5369, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16>
- %int32_6706 = torch.constant.int 32
- %5370 = torch.aten.floor_divide.Scalar %arg2, %int32_6706 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64>
- %int1_6707 = torch.constant.int 1
- %5371 = torch.aten.unsqueeze %5370, %int1_6707 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64>
- %int1_6708 = torch.constant.int 1
- %false_6709 = torch.constant.bool false
- %5372 = torch.aten.gather %arg3, %int1_6708, %5371, %false_6709 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>,
!torch.bool -> !torch.vtensor<[4,1],si64> - %int32_6710 = torch.constant.int 32 - %5373 = torch.aten.remainder.Scalar %arg2, %int32_6710 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4096_6678 = torch.constant.int 4096 + %6245 = torch.prim.ListConstruct %int4_6676, %int1_6677, %int4096_6678 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6246 = torch.aten.view %6244, %6245 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_6679 = torch.constant.int -2 + %int-1_6680 = torch.constant.int -1 + %6247 = torch.aten.transpose.int %375, %int-2_6679, %int-1_6680 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6681 = torch.constant.int 5 + %6248 = torch.prims.convert_element_type %6247, %int5_6681 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_6682 = torch.constant.int 4 + %int4096_6683 = torch.constant.int 4096 + %6249 = torch.prim.ListConstruct %int4_6682, %int4096_6683 : (!torch.int, !torch.int) -> !torch.list + %6250 = torch.aten.view %6246, %6249 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6251 = torch.aten.mm %6250, %6248 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_6684 = torch.constant.int 4 + %int1_6685 = torch.constant.int 1 + %int4096_6686 = torch.constant.int 4096 + %6252 = torch.prim.ListConstruct %int4_6684, %int1_6685, %int4096_6686 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6253 = torch.aten.view %6251, %6252 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_6687 = torch.constant.int 1 + %6254 = torch.aten.add.Tensor %6075, %6253, %int1_6687 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_6688 = torch.constant.int 6 + %6255 = torch.prims.convert_element_type %6254, %int6_6688 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_6689 = torch.constant.int 2 + %6256 = torch.aten.pow.Tensor_Scalar %6255, %int2_6689 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_6690 = torch.constant.int -1 + %6257 = torch.prim.ListConstruct %int-1_6690 : (!torch.int) -> !torch.list + %true_6691 = torch.constant.bool true + %none_6692 = torch.constant.none + %6258 = torch.aten.mean.dim %6256, %6257, %true_6691, %none_6692 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_6693 = torch.constant.float 9.9999997473787516E-6 + %int1_6694 = torch.constant.int 1 + %6259 = torch.aten.add.Scalar %6258, %float9.999990e-06_6693, %int1_6694 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %6260 = torch.aten.rsqrt %6259 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %6261 = torch.aten.mul.Tensor %6255, %6260 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_6695 = torch.constant.int 5 + %6262 = torch.prims.convert_element_type %6261, %int5_6695 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %6263 = torch.aten.mul.Tensor %376, %6262 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_6696 = torch.constant.int 5 + %6264 = torch.prims.convert_element_type %6263, %int5_6696 : 
!torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_6697 = torch.constant.int -2 + %int-1_6698 = torch.constant.int -1 + %6265 = torch.aten.transpose.int %377, %int-2_6697, %int-1_6698 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6699 = torch.constant.int 5 + %6266 = torch.prims.convert_element_type %6265, %int5_6699 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_6700 = torch.constant.int 4 + %int4096_6701 = torch.constant.int 4096 + %6267 = torch.prim.ListConstruct %int4_6700, %int4096_6701 : (!torch.int, !torch.int) -> !torch.list + %6268 = torch.aten.view %6264, %6267 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6269 = torch.aten.mm %6268, %6266 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_6702 = torch.constant.int 4 + %int1_6703 = torch.constant.int 1 + %int14336_6704 = torch.constant.int 14336 + %6270 = torch.prim.ListConstruct %int4_6702, %int1_6703, %int14336_6704 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6271 = torch.aten.view %6269, %6270 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %6272 = torch.aten.silu %6271 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_6705 = torch.constant.int -2 + %int-1_6706 = torch.constant.int -1 + %6273 = torch.aten.transpose.int %378, %int-2_6705, %int-1_6706 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6707 = torch.constant.int 5 + %6274 = torch.prims.convert_element_type %6273, %int5_6707 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_6708 = torch.constant.int 4 + %int4096_6709 = torch.constant.int 4096 + %6275 = torch.prim.ListConstruct %int4_6708, %int4096_6709 : (!torch.int, !torch.int) -> !torch.list + %6276 = torch.aten.view %6264, %6275 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6277 = torch.aten.mm %6276, %6274 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_6710 = torch.constant.int 4 %int1_6711 = torch.constant.int 1 - %5374 = torch.aten.unsqueeze %5373, %int1_6711 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_6712 = torch.constant.none - %5375 = torch.aten.clone %270, %none_6712 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_6713 = torch.constant.int 0 - %5376 = torch.aten.unsqueeze %5375, %int0_6713 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_6714 = torch.constant.int 4 - %int1_6715 = torch.constant.int 1 - %5377 = torch.prim.ListConstruct %int4_6714, %int1_6715 : (!torch.int, !torch.int) -> !torch.list - %int1_6716 = torch.constant.int 1 - %int1_6717 = torch.constant.int 1 - %5378 = torch.prim.ListConstruct %int1_6716, %int1_6717 : (!torch.int, !torch.int) -> !torch.list + %int14336_6712 = torch.constant.int 14336 + %6278 = torch.prim.ListConstruct %int4_6710, %int1_6711, %int14336_6712 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6279 = torch.aten.view %6277, %6278 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %6280 = torch.aten.mul.Tensor %6272, %6279 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_6713 = torch.constant.int -2 + 
%int-1_6714 = torch.constant.int -1 + %6281 = torch.aten.transpose.int %379, %int-2_6713, %int-1_6714 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_6715 = torch.constant.int 5 + %6282 = torch.prims.convert_element_type %6281, %int5_6715 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_6716 = torch.constant.int 4 + %int14336_6717 = torch.constant.int 14336 + %6283 = torch.prim.ListConstruct %int4_6716, %int14336_6717 : (!torch.int, !torch.int) -> !torch.list + %6284 = torch.aten.view %6280, %6283 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %6285 = torch.aten.mm %6284, %6282 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_6718 = torch.constant.int 4 - %int0_6719 = torch.constant.int 0 - %cpu_6720 = torch.constant.device "cpu" - %false_6721 = torch.constant.bool false - %5379 = torch.aten.empty_strided %5377, %5378, %int4_6718, %int0_6719, %cpu_6720, %false_6721 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int24_6722 = torch.constant.int 24 - %5380 = torch.aten.fill.Scalar %5379, %int24_6722 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_6723 = torch.constant.int 4 - %int1_6724 = torch.constant.int 1 - %5381 = torch.prim.ListConstruct %int4_6723, %int1_6724 : (!torch.int, !torch.int) -> !torch.list - %5382 = torch.aten.repeat %5376, %5381 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_6725 = torch.constant.int 32 - %5383 = torch.aten.mul.Scalar %5372, %int32_6725 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6726 = torch.constant.int 1 - %5384 = torch.aten.add.Tensor %5383, %5380, %int1_6726 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_6727 = torch.constant.int 2 - %5385 = torch.aten.mul.Scalar %5384, %int2_6727 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_6719 = torch.constant.int 1 + %int4096_6720 = torch.constant.int 4096 + %6286 = torch.prim.ListConstruct %int4_6718, %int1_6719, %int4096_6720 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6287 = torch.aten.view %6285, %6286 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_6721 = torch.constant.int 1 + %6288 = torch.aten.add.Tensor %6254, %6287, %int1_6721 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_6722 = torch.constant.int 6 + %6289 = torch.prims.convert_element_type %6288, %int6_6722 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_6723 = torch.constant.int 2 + %6290 = torch.aten.pow.Tensor_Scalar %6289, %int2_6723 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_6724 = torch.constant.int -1 + %6291 = torch.prim.ListConstruct %int-1_6724 : (!torch.int) -> !torch.list + %true_6725 = torch.constant.bool true + %none_6726 = torch.constant.none + %6292 = torch.aten.mean.dim %6290, %6291, %true_6725, %none_6726 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_6727 = torch.constant.float 9.9999997473787516E-6 %int1_6728 = torch.constant.int 1 - %5386 = torch.aten.add.Tensor %5385, %5382, %int1_6728 : !torch.vtensor<[4,1],si64>, 
!torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_6729 = torch.constant.int 32 - %5387 = torch.aten.mul.Scalar %5386, %int32_6729 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6730 = torch.constant.int 1 - %5388 = torch.aten.add.Tensor %5387, %5374, %int1_6730 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %5389 = torch.prim.ListConstruct %5388 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_6731 = torch.constant.bool false - %5390 = torch.aten.index_put %5369, %5389, %5321, %false_6731 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5390, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_6732 = torch.constant.int 32 - %int2_6733 = torch.constant.int 2 - %int32_6734 = torch.constant.int 32 - %int8_6735 = torch.constant.int 8 - %int128_6736 = torch.constant.int 128 - %5391 = torch.prim.ListConstruct %437, %int32_6732, %int2_6733, %int32_6734, %int8_6735, %int128_6736 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5392 = torch.aten.view %5390, %5391 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5392, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6737 = torch.constant.int 2097152 - %5393 = torch.prim.ListConstruct %437, %int2097152_6737 : (!torch.int, !torch.int) -> !torch.list - %5394 = torch.aten.view %5392, %5393 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5394, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_6738 = torch.constant.int 4 - %5395 = torch.prim.ListConstruct %int4_6738, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_6739 = torch.constant.int 1 - %5396 = torch.prim.ListConstruct %358, %int1_6739 : (!torch.int, !torch.int) -> !torch.list - %int4_6740 = torch.constant.int 4 - %int0_6741 = torch.constant.int 0 - %cpu_6742 = torch.constant.device "cpu" - %false_6743 = torch.constant.bool false - %5397 = torch.aten.empty_strided %5395, %5396, %int4_6740, %int0_6741, %cpu_6742, %false_6743 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5397, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int24_6744 = torch.constant.int 24 - %5398 = torch.aten.fill.Scalar %5397, %int24_6744 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5398, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_6745 = torch.constant.int 32 - %5399 = torch.aten.mul.Scalar %arg3, %int32_6745 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5399, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_6746 = torch.constant.int 1 - %5400 = torch.aten.add.Tensor %5399, %5398, %int1_6746 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5400, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_6747 = torch.constant.int 4 - %5401 = torch.aten.mul.int %int4_6747, %358 : !torch.int, !torch.int -> !torch.int - %5402 = 
torch.prim.ListConstruct %5401 : (!torch.int) -> !torch.list - %5403 = torch.aten.view %5400, %5402 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %5403, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_6748 = torch.constant.int 32 - %int2_6749 = torch.constant.int 2 - %int32_6750 = torch.constant.int 32 - %int8_6751 = torch.constant.int 8 - %int128_6752 = torch.constant.int 128 - %5404 = torch.prim.ListConstruct %437, %int32_6748, %int2_6749, %int32_6750, %int8_6751, %int128_6752 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5405 = torch.aten.view %5394, %5404 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5405, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6753 = torch.constant.int 32 - %5406 = torch.aten.mul.int %437, %int32_6753 : !torch.int, !torch.int -> !torch.int - %int2_6754 = torch.constant.int 2 - %int32_6755 = torch.constant.int 32 - %int8_6756 = torch.constant.int 8 - %int128_6757 = torch.constant.int 128 - %5407 = torch.prim.ListConstruct %5406, %int2_6754, %int32_6755, %int8_6756, %int128_6757 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5408 = torch.aten.view %5405, %5407 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %5408, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_6758 = torch.constant.int 0 - %5409 = torch.aten.index_select %5408, %int0_6758, %5403 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %5409, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> + %6293 = torch.aten.add.Scalar %6292, %float9.999990e-06_6727, %int1_6728 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %6294 = torch.aten.rsqrt %6293 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %6295 = torch.aten.mul.Tensor %6289, %6294 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_6729 = torch.constant.int 5 + %6296 = torch.prims.convert_element_type %6295, %int5_6729 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %6297 = torch.aten.mul.Tensor %380, %6296 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_6730 = torch.constant.int 5 + %6298 = torch.prims.convert_element_type %6297, %int5_6730 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_6731 = torch.constant.int -2 + %int-1_6732 = torch.constant.int -1 + %6299 = torch.aten.transpose.int %381, %int-2_6731, %int-1_6732 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6733 = torch.constant.int 5 + %6300 = torch.prims.convert_element_type %6299, %int5_6733 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_6734 = torch.constant.int 4 + %int4096_6735 = torch.constant.int 4096 + %6301 = torch.prim.ListConstruct %int4_6734, %int4096_6735 : (!torch.int, !torch.int) -> !torch.list + %6302 = torch.aten.view %6298, %6301 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6303 = 
torch.aten.mm %6302, %6300 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_6736 = torch.constant.int 4 + %int1_6737 = torch.constant.int 1 + %int4096_6738 = torch.constant.int 4096 + %6304 = torch.prim.ListConstruct %int4_6736, %int1_6737, %int4096_6738 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6305 = torch.aten.view %6303, %6304 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_6739 = torch.constant.int -2 + %int-1_6740 = torch.constant.int -1 + %6306 = torch.aten.transpose.int %382, %int-2_6739, %int-1_6740 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6741 = torch.constant.int 5 + %6307 = torch.prims.convert_element_type %6306, %int5_6741 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_6742 = torch.constant.int 4 + %int4096_6743 = torch.constant.int 4096 + %6308 = torch.prim.ListConstruct %int4_6742, %int4096_6743 : (!torch.int, !torch.int) -> !torch.list + %6309 = torch.aten.view %6298, %6308 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6310 = torch.aten.mm %6309, %6307 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_6744 = torch.constant.int 4 + %int1_6745 = torch.constant.int 1 + %int1024_6746 = torch.constant.int 1024 + %6311 = torch.prim.ListConstruct %int4_6744, %int1_6745, %int1024_6746 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6312 = torch.aten.view %6310, %6311 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_6747 = torch.constant.int -2 + %int-1_6748 = torch.constant.int -1 + %6313 = torch.aten.transpose.int %383, %int-2_6747, %int-1_6748 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6749 = torch.constant.int 5 + %6314 = torch.prims.convert_element_type %6313, %int5_6749 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_6750 = torch.constant.int 4 + %int4096_6751 = torch.constant.int 4096 + %6315 = torch.prim.ListConstruct %int4_6750, %int4096_6751 : (!torch.int, !torch.int) -> !torch.list + %6316 = torch.aten.view %6298, %6315 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6317 = torch.aten.mm %6316, %6314 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_6752 = torch.constant.int 4 + %int1_6753 = torch.constant.int 1 + %int1024_6754 = torch.constant.int 1024 + %6318 = torch.prim.ListConstruct %int4_6752, %int1_6753, %int1024_6754 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6319 = torch.aten.view %6317, %6318 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_6755 = torch.constant.int 4 + %int1_6756 = torch.constant.int 1 + %int32_6757 = torch.constant.int 32 + %int128_6758 = torch.constant.int 128 + %6320 = torch.prim.ListConstruct %int4_6755, %int1_6756, %int32_6757, %int128_6758 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6321 = torch.aten.view %6305, %6320 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> %int4_6759 = torch.constant.int 4 - %int2_6760 = torch.constant.int 2 - %int32_6761 = torch.constant.int 32 - %int8_6762 = torch.constant.int 8 - %int128_6763 = torch.constant.int 128 - %5410 = torch.prim.ListConstruct %int4_6759, %358, 
%int2_6760, %int32_6761, %int8_6762, %int128_6763 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5411 = torch.aten.view %5409, %5410 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5411, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_6764 = torch.constant.int 0 - %int0_6765 = torch.constant.int 0 - %int9223372036854775807_6766 = torch.constant.int 9223372036854775807 + %int1_6760 = torch.constant.int 1 + %int8_6761 = torch.constant.int 8 + %int128_6762 = torch.constant.int 128 + %6322 = torch.prim.ListConstruct %int4_6759, %int1_6760, %int8_6761, %int128_6762 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6323 = torch.aten.view %6312, %6322 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_6763 = torch.constant.int 4 + %int1_6764 = torch.constant.int 1 + %int8_6765 = torch.constant.int 8 + %int128_6766 = torch.constant.int 128 + %6324 = torch.prim.ListConstruct %int4_6763, %int1_6764, %int8_6765, %int128_6766 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6325 = torch.aten.view %6319, %6324 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> %int1_6767 = torch.constant.int 1 - %5412 = torch.aten.slice.Tensor %5411, %int0_6764, %int0_6765, %int9223372036854775807_6766, %int1_6767 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5412, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_6768 = torch.constant.int 1 - %int0_6769 = torch.constant.int 0 - %int9223372036854775807_6770 = torch.constant.int 9223372036854775807 - %int1_6771 = torch.constant.int 1 - %5413 = torch.aten.slice.Tensor %5412, %int1_6768, %int0_6769, %int9223372036854775807_6770, %int1_6771 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5413, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_6772 = torch.constant.int 2 - %int0_6773 = torch.constant.int 0 - %5414 = torch.aten.select.int %5413, %int2_6772, %int0_6773 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5414, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_6774 = torch.constant.int 32 - %5415 = torch.aten.mul.int %358, %int32_6774 : !torch.int, !torch.int -> !torch.int - %int2_6775 = torch.constant.int 2 - %int0_6776 = torch.constant.int 0 - %int1_6777 = torch.constant.int 1 - %5416 = torch.aten.slice.Tensor %5414, %int2_6775, %int0_6776, %5415, %int1_6777 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5416, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_6778 = torch.constant.int 0 - %5417 = torch.aten.clone %5416, %int0_6778 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5417, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int2_6768 = torch.constant.int 2 + %6326 = torch.aten.transpose.int 
%6321, %int1_6767, %int2_6768 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %6327 = torch.aten.mul.Tensor %6326, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_6769 = torch.constant.int 3 + %int0_6770 = torch.constant.int 0 + %int64_6771 = torch.constant.int 64 + %int1_6772 = torch.constant.int 1 + %6328 = torch.aten.slice.Tensor %6326, %int3_6769, %int0_6770, %int64_6771, %int1_6772 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_6773 = torch.constant.int 3 + %int64_6774 = torch.constant.int 64 + %int9223372036854775807_6775 = torch.constant.int 9223372036854775807 + %int1_6776 = torch.constant.int 1 + %6329 = torch.aten.slice.Tensor %6326, %int3_6773, %int64_6774, %int9223372036854775807_6775, %int1_6776 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %6330 = torch.aten.neg %6329 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %6331 = torch.prim.ListConstruct %6330, %6328 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_6777 = torch.constant.int -1 + %6332 = torch.aten.cat %6331, %int-1_6777 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %6333 = torch.aten.mul.Tensor %6332, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_6778 = torch.constant.int 1 + %6334 = torch.aten.add.Tensor %6327, %6333, %int1_6778 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_6779 = torch.constant.int 1 - %5418 = torch.aten.size.int %5413, %int1_6779 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_6780 = torch.constant.int 32 - %5419 = torch.aten.mul.int %5418, %int32_6780 : !torch.int, !torch.int -> !torch.int - %int4_6781 = torch.constant.int 4 - %int8_6782 = torch.constant.int 8 - %int128_6783 = torch.constant.int 128 - %5420 = torch.prim.ListConstruct %int4_6781, %5419, %int8_6782, %int128_6783 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5421 = torch.aten._unsafe_view %5417, %5420 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5421, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_6780 = torch.constant.int 2 + %6335 = torch.aten.transpose.int %6334, %int1_6779, %int2_6780 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_6781 = torch.constant.int 1 + %int2_6782 = torch.constant.int 2 + %6336 = torch.aten.transpose.int %6323, %int1_6781, %int2_6782 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %6337 = torch.aten.mul.Tensor %6336, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_6783 = torch.constant.int 3 %int0_6784 = torch.constant.int 0 - %int0_6785 = torch.constant.int 0 - %int9223372036854775807_6786 = torch.constant.int 9223372036854775807 - %int1_6787 = torch.constant.int 1 - %5422 = torch.aten.slice.Tensor %5421, %int0_6784, %int0_6785, %int9223372036854775807_6786, %int1_6787 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> 
!torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5422, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_6788 = torch.constant.int 0 - %int0_6789 = torch.constant.int 0 - %int9223372036854775807_6790 = torch.constant.int 9223372036854775807 - %int1_6791 = torch.constant.int 1 - %5423 = torch.aten.slice.Tensor %5411, %int0_6788, %int0_6789, %int9223372036854775807_6790, %int1_6791 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5423, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> + %int64_6785 = torch.constant.int 64 + %int1_6786 = torch.constant.int 1 + %6338 = torch.aten.slice.Tensor %6336, %int3_6783, %int0_6784, %int64_6785, %int1_6786 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_6787 = torch.constant.int 3 + %int64_6788 = torch.constant.int 64 + %int9223372036854775807_6789 = torch.constant.int 9223372036854775807 + %int1_6790 = torch.constant.int 1 + %6339 = torch.aten.slice.Tensor %6336, %int3_6787, %int64_6788, %int9223372036854775807_6789, %int1_6790 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %6340 = torch.aten.neg %6339 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %6341 = torch.prim.ListConstruct %6340, %6338 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_6791 = torch.constant.int -1 + %6342 = torch.aten.cat %6341, %int-1_6791 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %6343 = torch.aten.mul.Tensor %6342, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> %int1_6792 = torch.constant.int 1 - %int0_6793 = torch.constant.int 0 - %int9223372036854775807_6794 = torch.constant.int 9223372036854775807 - %int1_6795 = torch.constant.int 1 - %5424 = torch.aten.slice.Tensor %5423, %int1_6792, %int0_6793, %int9223372036854775807_6794, %int1_6795 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5424, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_6796 = torch.constant.int 2 + %6344 = torch.aten.add.Tensor %6337, %6343, %int1_6792 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_6793 = torch.constant.int 1 + %int2_6794 = torch.constant.int 2 + %6345 = torch.aten.transpose.int %6344, %int1_6793, %int2_6794 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_6795 = torch.constant.int 32 + %6346 = torch.aten.floor_divide.Scalar %arg2, %int32_6795 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_6796 = torch.constant.int 1 + %6347 = torch.aten.unsqueeze %6346, %int1_6796 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> %int1_6797 = torch.constant.int 1 - %5425 = torch.aten.select.int %5424, %int2_6796, %int1_6797 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5425, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_6798 = torch.constant.int 2 - %int0_6799 
= torch.constant.int 0 + %false_6798 = torch.constant.bool false + %6348 = torch.aten.gather %arg3, %int1_6797, %6347, %false_6798 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_6799 = torch.constant.int 4 %int1_6800 = torch.constant.int 1 - %5426 = torch.aten.slice.Tensor %5425, %int2_6798, %int0_6799, %5415, %int1_6800 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5426, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_6801 = torch.constant.int 0 - %5427 = torch.aten.clone %5426, %int0_6801 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5427, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_6802 = torch.constant.int 1 - %5428 = torch.aten.size.int %5424, %int1_6802 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_6803 = torch.constant.int 32 - %5429 = torch.aten.mul.int %5428, %int32_6803 : !torch.int, !torch.int -> !torch.int - %int4_6804 = torch.constant.int 4 - %int8_6805 = torch.constant.int 8 - %int128_6806 = torch.constant.int 128 - %5430 = torch.prim.ListConstruct %int4_6804, %5429, %int8_6805, %int128_6806 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5431 = torch.aten._unsafe_view %5427, %5430 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5431, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_6807 = torch.constant.int 0 - %int0_6808 = torch.constant.int 0 - %int9223372036854775807_6809 = torch.constant.int 9223372036854775807 - %int1_6810 = torch.constant.int 1 - %5432 = torch.aten.slice.Tensor %5431, %int0_6807, %int0_6808, %int9223372036854775807_6809, %int1_6810 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5432, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_6811 = torch.constant.int -2 - %5433 = torch.aten.unsqueeze %5422, %int-2_6811 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5433, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int1_6801 = torch.constant.int 1 + %6349 = torch.prim.ListConstruct %int4_6799, %int1_6800, %int1_6801 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6350 = torch.aten.view %6348, %6349 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_6802 = torch.constant.int 32 + %6351 = torch.aten.remainder.Scalar %arg2, %int32_6802 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_6803 = torch.constant.int 4 + %int1_6804 = torch.constant.int 1 + %int1_6805 = torch.constant.int 1 + %6352 = torch.prim.ListConstruct %int4_6803, %int1_6804, %int1_6805 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6353 = torch.aten.view %6351, %6352 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_6806 = torch.constant.int 8 + %none_6807 = torch.constant.none + %none_6808 = torch.constant.none + %cpu_6809 = torch.constant.device "cpu" + %false_6810 = torch.constant.bool false + %6354 = torch.aten.arange %int8_6806, %none_6807, 
%none_6808, %cpu_6809, %false_6810 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_6811 = torch.constant.int 1 %int1_6812 = torch.constant.int 1 - %5434 = torch.aten.size.int %5421, %int1_6812 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_6813 = torch.constant.int 4 - %int8_6814 = torch.constant.int 8 - %int4_6815 = torch.constant.int 4 - %int128_6816 = torch.constant.int 128 - %5435 = torch.prim.ListConstruct %int4_6813, %5434, %int8_6814, %int4_6815, %int128_6816 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_6817 = torch.constant.bool false - %5436 = torch.aten.expand %5433, %5435, %false_6817 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5436, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_6818 = torch.constant.int 0 - %5437 = torch.aten.clone %5436, %int0_6818 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5437, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_6819 = torch.constant.int 4 - %int32_6820 = torch.constant.int 32 - %int128_6821 = torch.constant.int 128 - %5438 = torch.prim.ListConstruct %int4_6819, %5434, %int32_6820, %int128_6821 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5439 = torch.aten._unsafe_view %5437, %5438 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5439, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_6822 = torch.constant.int -2 - %5440 = torch.aten.unsqueeze %5432, %int-2_6822 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5440, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int8_6813 = torch.constant.int 8 + %6355 = torch.prim.ListConstruct %int1_6811, %int1_6812, %int8_6813 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6356 = torch.aten.view %6354, %6355 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_6814 = torch.constant.none + %6357 = torch.aten.clone %384, %none_6814 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6358 = torch.aten.detach %6357 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6359 = torch.aten.detach %6358 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6360 = torch.aten.detach %6359 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_6815 = torch.constant.int 1 + %int1_6816 = torch.constant.int 1 + %int1_6817 = torch.constant.int 1 + %6361 = torch.prim.ListConstruct %int1_6815, %int1_6816, %int1_6817 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6362 = torch.aten.view %6360, %6361 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_6818 = torch.constant.int 32 + %6363 = torch.aten.mul.Scalar %6350, %int32_6818 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int27 = torch.constant.int 27 + %int1_6819 = torch.constant.int 1 + %6364 = torch.aten.add.Scalar %6363, %int27, %int1_6819 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_6820 = torch.constant.int 2 + %6365 = torch.aten.mul.Scalar %6364, %int2_6820 : 
!torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_6821 = torch.constant.int 1 + %6366 = torch.aten.add.Tensor %6365, %6362, %int1_6821 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_6822 = torch.constant.int 8 + %6367 = torch.aten.mul.Scalar %6366, %int8_6822 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_6823 = torch.constant.int 1 - %5441 = torch.aten.size.int %5431, %int1_6823 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_6824 = torch.constant.int 4 - %int8_6825 = torch.constant.int 8 - %int4_6826 = torch.constant.int 4 - %int128_6827 = torch.constant.int 128 - %5442 = torch.prim.ListConstruct %int4_6824, %5441, %int8_6825, %int4_6826, %int128_6827 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_6828 = torch.constant.bool false - %5443 = torch.aten.expand %5440, %5442, %false_6828 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5443, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_6829 = torch.constant.int 0 - %5444 = torch.aten.clone %5443, %int0_6829 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5444, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_6830 = torch.constant.int 4 - %int32_6831 = torch.constant.int 32 + %6368 = torch.aten.add.Tensor %6367, %6356, %int1_6823 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_6824 = torch.constant.int 32 + %6369 = torch.aten.mul.Scalar %6368, %int32_6824 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_6825 = torch.constant.int 1 + %6370 = torch.aten.add.Tensor %6369, %6353, %int1_6825 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_6826 = torch.constant.int 5 + %6371 = torch.prims.convert_element_type %6345, %int5_6826 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_6827 = torch.constant.int 32 + %int2_6828 = torch.constant.int 2 + %int8_6829 = torch.constant.int 8 + %int32_6830 = torch.constant.int 32 + %int128_6831 = torch.constant.int 128 + %6372 = torch.prim.ListConstruct %456, %int32_6827, %int2_6828, %int8_6829, %int32_6830, %int128_6831 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6373 = torch.aten.view %6193, %6372 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6373, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> %int128_6832 = torch.constant.int 128 - %5445 = torch.prim.ListConstruct %int4_6830, %5441, %int32_6831, %int128_6832 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5446 = torch.aten._unsafe_view %5444, %5445 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5446, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_6833 = torch.constant.int 1 - %int2_6834 = torch.constant.int 2 - %5447 = torch.aten.transpose.int %5327, %int1_6833, %int2_6834 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> 
!torch.vtensor<[4,32,1,128],f16> - %int1_6835 = torch.constant.int 1 - %int2_6836 = torch.constant.int 2 - %5448 = torch.aten.transpose.int %5439, %int1_6835, %int2_6836 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5448, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_6837 = torch.constant.int 1 - %int2_6838 = torch.constant.int 2 - %5449 = torch.aten.transpose.int %5446, %int1_6837, %int2_6838 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5449, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_6839 = torch.constant.float 0.000000e+00 - %false_6840 = torch.constant.bool false - %none_6841 = torch.constant.none - %5450:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5447, %5448, %5449, %float0.000000e00_6839, %false_6840, %368, %none_6841) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_6842 = torch.constant.int 1 - %int2_6843 = torch.constant.int 2 - %5451 = torch.aten.transpose.int %5450#0, %int1_6842, %int2_6843 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_6844 = torch.constant.int 4 - %int1_6845 = torch.constant.int 1 - %int4096_6846 = torch.constant.int 4096 - %5452 = torch.prim.ListConstruct %int4_6844, %int1_6845, %int4096_6846 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> - %5453 = torch.aten.view %5451, %5452 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16> - %int-2_6847 = torch.constant.int -2 - %int-1_6848 = torch.constant.int -1 - %5454 = torch.aten.transpose.int %271, %int-2_6847, %int-1_6848 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6849 = torch.constant.int 4 - %int4096_6850 = torch.constant.int 4096 - %5455 = torch.prim.ListConstruct %int4_6849, %int4096_6850 : (!torch.int, !torch.int) -> !torch.list<int> - %5456 = torch.aten.view %5453, %5455 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16> - %5457 = torch.aten.mm %5456, %5454 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_6851 = torch.constant.int 4 + %6374 = torch.prim.ListConstruct %596, %int128_6832 : (!torch.int, !torch.int) -> !torch.list<int> + %6375 = torch.aten.view %6373, %6374 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list<int> -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6375, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %6376 = torch.prim.ListConstruct %6370 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>> + %false_6833 = torch.constant.bool false + %6377 = torch.aten.index_put %6375, %6376, %6371, %false_6833 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6377, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_6834 = torch.constant.int 32 + %int2_6835 = torch.constant.int 2 + %int8_6836 = torch.constant.int 8 + %int32_6837 = torch.constant.int 32 + %int128_6838 = torch.constant.int 128 + %6378 =
torch.prim.ListConstruct %456, %int32_6834, %int2_6835, %int8_6836, %int32_6837, %int128_6838 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6379 = torch.aten.view %6377, %6378 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6379, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_6839 = torch.constant.int 2097152 + %6380 = torch.prim.ListConstruct %456, %int2097152_6839 : (!torch.int, !torch.int) -> !torch.list + %6381 = torch.aten.view %6379, %6380 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6381, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_6840 = torch.constant.int 32 + %int2_6841 = torch.constant.int 2 + %int8_6842 = torch.constant.int 8 + %int32_6843 = torch.constant.int 32 + %int128_6844 = torch.constant.int 128 + %6382 = torch.prim.ListConstruct %456, %int32_6840, %int2_6841, %int8_6842, %int32_6843, %int128_6844 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6383 = torch.aten.view %6381, %6382 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6383, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_6845 = torch.constant.int 128 + %6384 = torch.prim.ListConstruct %596, %int128_6845 : (!torch.int, !torch.int) -> !torch.list + %6385 = torch.aten.view %6383, %6384 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6385, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_6846 = torch.constant.none + %6386 = torch.aten.clone %385, %none_6846 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6387 = torch.aten.detach %6386 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6388 = torch.aten.detach %6387 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6389 = torch.aten.detach %6388 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_6847 = torch.constant.int 1 + %int1_6848 = torch.constant.int 1 + %int1_6849 = torch.constant.int 1 + %6390 = torch.prim.ListConstruct %int1_6847, %int1_6848, %int1_6849 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6391 = torch.aten.view %6389, %6390 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_6850 = torch.constant.int 32 + %6392 = torch.aten.mul.Scalar %6350, %int32_6850 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int27_6851 = torch.constant.int 27 %int1_6852 = torch.constant.int 1 - %int4096_6853 = torch.constant.int 4096 - %5458 = torch.prim.ListConstruct %int4_6851, %int1_6852, %int4096_6853 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5459 = torch.aten.view %5457, %5458 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %6393 = torch.aten.add.Scalar %6392, %int27_6851, %int1_6852 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_6853 = torch.constant.int 2 + %6394 = torch.aten.mul.Scalar %6393, %int2_6853 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_6854 = torch.constant.int 1 - %5460 = torch.aten.add.Tensor %5287, %5459, %int1_6854 : !torch.vtensor<[4,1,4096],f16>, 
!torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_6855 = torch.constant.int 6 - %5461 = torch.prims.convert_element_type %5460, %int6_6855 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_6856 = torch.constant.int 2 - %5462 = torch.aten.pow.Tensor_Scalar %5461, %int2_6856 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_6857 = torch.constant.int -1 - %5463 = torch.prim.ListConstruct %int-1_6857 : (!torch.int) -> !torch.list - %true_6858 = torch.constant.bool true - %none_6859 = torch.constant.none - %5464 = torch.aten.mean.dim %5462, %5463, %true_6858, %none_6859 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_6860 = torch.constant.float 9.9999997473787516E-6 - %int1_6861 = torch.constant.int 1 - %5465 = torch.aten.add.Scalar %5464, %float9.999990e-06_6860, %int1_6861 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %5466 = torch.aten.rsqrt %5465 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %5467 = torch.aten.mul.Tensor %5461, %5466 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_6862 = torch.constant.int 5 - %5468 = torch.prims.convert_element_type %5467, %int5_6862 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %5469 = torch.aten.mul.Tensor %272, %5468 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_6863 = torch.constant.int 5 - %5470 = torch.prims.convert_element_type %5469, %int5_6863 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_6864 = torch.constant.int -2 - %int-1_6865 = torch.constant.int -1 - %5471 = torch.aten.transpose.int %273, %int-2_6864, %int-1_6865 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_6866 = torch.constant.int 4 - %int4096_6867 = torch.constant.int 4096 - %5472 = torch.prim.ListConstruct %int4_6866, %int4096_6867 : (!torch.int, !torch.int) -> !torch.list - %5473 = torch.aten.view %5470, %5472 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5474 = torch.aten.mm %5473, %5471 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_6868 = torch.constant.int 4 - %int1_6869 = torch.constant.int 1 - %int14336_6870 = torch.constant.int 14336 - %5475 = torch.prim.ListConstruct %int4_6868, %int1_6869, %int14336_6870 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5476 = torch.aten.view %5474, %5475 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %5477 = torch.aten.silu %5476 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_6871 = torch.constant.int -2 - %int-1_6872 = torch.constant.int -1 - %5478 = torch.aten.transpose.int %274, %int-2_6871, %int-1_6872 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_6873 = torch.constant.int 4 - %int4096_6874 = torch.constant.int 4096 - %5479 = torch.prim.ListConstruct %int4_6873, %int4096_6874 : (!torch.int, !torch.int) -> !torch.list - %5480 = torch.aten.view %5470, %5479 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5481 = torch.aten.mm %5480, %5478 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> 
!torch.vtensor<[4,14336],f16> - %int4_6875 = torch.constant.int 4 - %int1_6876 = torch.constant.int 1 - %int14336_6877 = torch.constant.int 14336 - %5482 = torch.prim.ListConstruct %int4_6875, %int1_6876, %int14336_6877 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5483 = torch.aten.view %5481, %5482 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %5484 = torch.aten.mul.Tensor %5477, %5483 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_6878 = torch.constant.int -2 - %int-1_6879 = torch.constant.int -1 - %5485 = torch.aten.transpose.int %275, %int-2_6878, %int-1_6879 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_6880 = torch.constant.int 4 - %int14336_6881 = torch.constant.int 14336 - %5486 = torch.prim.ListConstruct %int4_6880, %int14336_6881 : (!torch.int, !torch.int) -> !torch.list - %5487 = torch.aten.view %5484, %5486 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %5488 = torch.aten.mm %5487, %5485 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %6395 = torch.aten.add.Tensor %6394, %6391, %int1_6854 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_6855 = torch.constant.int 8 + %6396 = torch.aten.mul.Scalar %6395, %int8_6855 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_6856 = torch.constant.int 1 + %6397 = torch.aten.add.Tensor %6396, %6356, %int1_6856 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_6857 = torch.constant.int 32 + %6398 = torch.aten.mul.Scalar %6397, %int32_6857 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_6858 = torch.constant.int 1 + %6399 = torch.aten.add.Tensor %6398, %6353, %int1_6858 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_6859 = torch.constant.int 5 + %6400 = torch.prims.convert_element_type %6325, %int5_6859 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %6401 = torch.prim.ListConstruct %6399 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_6860 = torch.constant.bool false + %6402 = torch.aten.index_put %6385, %6401, %6400, %false_6860 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6402, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_6861 = torch.constant.int 32 + %int2_6862 = torch.constant.int 2 + %int8_6863 = torch.constant.int 8 + %int32_6864 = torch.constant.int 32 + %int128_6865 = torch.constant.int 128 + %6403 = torch.prim.ListConstruct %456, %int32_6861, %int2_6862, %int8_6863, %int32_6864, %int128_6865 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6404 = torch.aten.view %6402, %6403 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6404, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_6866 = torch.constant.int 2097152 + %6405 = torch.prim.ListConstruct %456, %int2097152_6866 : (!torch.int, !torch.int) -> !torch.list + %6406 = torch.aten.view %6404, %6405 : 
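// NOTE (hedged summary of the `+` ops above): %6395..%6399 linearize a cache
// coordinate per kv head -- a base page/partition index is scaled by 8 (the kv-head
// count), offset by the [1,1,8] head iota, scaled by 32 (the block stride), and
// offset by the in-page position -- yielding [4,1,8] row indices into the
// [?,128] (s0*16384 x 128) view of the cache. index_put then scatters this step's
// [4,1,8,128] K/V slice, and the views that follow restore the [?,32,2,8,32,128]
// block layout and the flat [?,2097152] cache handle.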
!torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6406, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_6867 = torch.constant.none + %6407 = torch.aten.clone %386, %none_6867 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6408 = torch.aten.detach %6407 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6409 = torch.aten.detach %6408 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6410 = torch.aten.detach %6409 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_6868 = torch.constant.none + %6411 = torch.aten.clone %387, %none_6868 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6412 = torch.aten.detach %6411 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6413 = torch.aten.detach %6412 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6414 = torch.aten.detach %6413 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_6869 = torch.constant.none + %6415 = torch.aten.clone %388, %none_6869 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6416 = torch.aten.detach %6415 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6417 = torch.aten.detach %6416 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6418 = torch.aten.detach %6417 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_6870 = torch.constant.int 32 + %int2_6871 = torch.constant.int 2 + %int8_6872 = torch.constant.int 8 + %int32_6873 = torch.constant.int 32 + %int128_6874 = torch.constant.int 128 + %6419 = torch.prim.ListConstruct %456, %int32_6870, %int2_6871, %int8_6872, %int32_6873, %int128_6874 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6420 = torch.aten.view %6406, %6419 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6420, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %6421 = torch_c.to_builtin_tensor %6420 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %6422 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_6875 = tensor.cast %6422 : tensor<4x?xi64> to tensor + %6423 = torch_c.to_builtin_tensor %6410 : !torch.vtensor<[],si64> -> tensor + %6424 = torch_c.to_builtin_tensor %6414 : !torch.vtensor<[],si64> -> tensor + %6425 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%6421, %cast_6875, %6423, %6424) : (tensor, tensor, tensor, tensor) -> tensor + %cast_6876 = tensor.cast %6425 : tensor to tensor<4x?x8x32x128xf16> + %6426 = torch_c.from_builtin_tensor %cast_6876 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %6426, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %6427 = torch_c.to_builtin_tensor %6420 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %6428 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_6877 = tensor.cast %6428 : tensor<4x?xi64> to tensor + %6429 = torch_c.to_builtin_tensor %6410 : !torch.vtensor<[],si64> -> tensor + %6430 = torch_c.to_builtin_tensor %6418 : !torch.vtensor<[],si64> -> tensor + %6431 = util.call 
@paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%6427, %cast_6877, %6429, %6430) : (tensor, tensor, tensor, tensor) -> tensor + %cast_6878 = tensor.cast %6431 : tensor to tensor<4x?x8x32x128xf16> + %6432 = torch_c.from_builtin_tensor %cast_6878 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %6432, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_6879 = torch.constant.int 2 + %int3_6880 = torch.constant.int 3 + %6433 = torch.aten.transpose.int %6426, %int2_6879, %int3_6880 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6433, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_6881 = torch.constant.int 0 + %6434 = torch.aten.clone %6433, %int0_6881 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6434, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int4_6882 = torch.constant.int 4 - %int1_6883 = torch.constant.int 1 - %int4096_6884 = torch.constant.int 4096 - %5489 = torch.prim.ListConstruct %int4_6882, %int1_6883, %int4096_6884 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5490 = torch.aten.view %5488, %5489 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_6885 = torch.constant.int 1 - %5491 = torch.aten.add.Tensor %5460, %5490, %int1_6885 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_6886 = torch.constant.int 6 - %5492 = torch.prims.convert_element_type %5491, %int6_6886 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_6887 = torch.constant.int 2 - %5493 = torch.aten.pow.Tensor_Scalar %5492, %int2_6887 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_6888 = torch.constant.int -1 - %5494 = torch.prim.ListConstruct %int-1_6888 : (!torch.int) -> !torch.list - %true_6889 = torch.constant.bool true - %none_6890 = torch.constant.none - %5495 = torch.aten.mean.dim %5493, %5494, %true_6889, %none_6890 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_6891 = torch.constant.float 9.9999997473787516E-6 - %int1_6892 = torch.constant.int 1 - %5496 = torch.aten.add.Scalar %5495, %float9.999990e-06_6891, %int1_6892 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %5497 = torch.aten.rsqrt %5496 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %5498 = torch.aten.mul.Tensor %5492, %5497 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_6893 = torch.constant.int 5 - %5499 = torch.prims.convert_element_type %5498, %int5_6893 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %5500 = torch.aten.mul.Tensor %276, %5499 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_6894 = torch.constant.int 5 - %5501 = torch.prims.convert_element_type %5500, %int5_6894 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_6895 = torch.constant.int -2 - 
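// NOTE: the two util.call @paged_attention_kv_cache_gather_* invocations above read
// this layer's K and V planes back out of the paged cache: each takes the
// [?,32,2,8,32,128] cache view, the [4,?] i64 page table, and two detached scalar
// selectors (by the callee naming, the transformer-block index and the K/V
// partition), returning [4, pages, 8, 32, 128]. The transpose(2,3)/clone/
// _unsafe_view sequences around this point swap each page's [head, position] axes
// and fold pages x 32 positions into a single [4, seq, 8, 128] sequence axis.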
%int-1_6896 = torch.constant.int -1 - %5502 = torch.aten.transpose.int %277, %int-2_6895, %int-1_6896 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_6897 = torch.constant.int 4 - %int4096_6898 = torch.constant.int 4096 - %5503 = torch.prim.ListConstruct %int4_6897, %int4096_6898 : (!torch.int, !torch.int) -> !torch.list - %5504 = torch.aten.view %5501, %5503 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5505 = torch.aten.mm %5504, %5502 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_6899 = torch.constant.int 4 - %int1_6900 = torch.constant.int 1 - %int4096_6901 = torch.constant.int 4096 - %5506 = torch.prim.ListConstruct %int4_6899, %int1_6900, %int4096_6901 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5507 = torch.aten.view %5505, %5506 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_6902 = torch.constant.int -2 - %int-1_6903 = torch.constant.int -1 - %5508 = torch.aten.transpose.int %278, %int-2_6902, %int-1_6903 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int8_6883 = torch.constant.int 8 + %int128_6884 = torch.constant.int 128 + %6435 = torch.prim.ListConstruct %int4_6882, %457, %int8_6883, %int128_6884 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6436 = torch.aten._unsafe_view %6434, %6435 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6436, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_6885 = torch.constant.int 2 + %int3_6886 = torch.constant.int 3 + %6437 = torch.aten.transpose.int %6432, %int2_6885, %int3_6886 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6437, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_6887 = torch.constant.int 0 + %6438 = torch.aten.clone %6437, %int0_6887 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6438, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_6888 = torch.constant.int 4 + %int8_6889 = torch.constant.int 8 + %int128_6890 = torch.constant.int 128 + %6439 = torch.prim.ListConstruct %int4_6888, %457, %int8_6889, %int128_6890 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6440 = torch.aten._unsafe_view %6438, %6439 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6440, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_6891 = torch.constant.int -2 + %6441 = torch.aten.unsqueeze %6436, %int-2_6891 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %6441, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_6892 = torch.constant.int 4 + %int8_6893 = torch.constant.int 8 + %int4_6894 = torch.constant.int 4 + %int128_6895 = torch.constant.int 128 + %6442 = torch.prim.ListConstruct %int4_6892, %457, %int8_6893, %int4_6894, %int128_6895 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_6896 = torch.constant.bool false + %6443 = torch.aten.expand %6441, %6442, %false_6896 
: !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6443, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_6897 = torch.constant.int 0 + %6444 = torch.aten.clone %6443, %int0_6897 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6444, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_6898 = torch.constant.int 4 + %int32_6899 = torch.constant.int 32 + %int128_6900 = torch.constant.int 128 + %6445 = torch.prim.ListConstruct %int4_6898, %457, %int32_6899, %int128_6900 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6446 = torch.aten._unsafe_view %6444, %6445 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6446, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_6901 = torch.constant.int -2 + %6447 = torch.aten.unsqueeze %6440, %int-2_6901 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %6447, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_6902 = torch.constant.int 4 + %int8_6903 = torch.constant.int 8 %int4_6904 = torch.constant.int 4 - %int4096_6905 = torch.constant.int 4096 - %5509 = torch.prim.ListConstruct %int4_6904, %int4096_6905 : (!torch.int, !torch.int) -> !torch.list - %5510 = torch.aten.view %5501, %5509 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5511 = torch.aten.mm %5510, %5508 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_6906 = torch.constant.int 4 - %int1_6907 = torch.constant.int 1 - %int1024_6908 = torch.constant.int 1024 - %5512 = torch.prim.ListConstruct %int4_6906, %int1_6907, %int1024_6908 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5513 = torch.aten.view %5511, %5512 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_6909 = torch.constant.int -2 - %int-1_6910 = torch.constant.int -1 - %5514 = torch.aten.transpose.int %279, %int-2_6909, %int-1_6910 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_6911 = torch.constant.int 4 - %int4096_6912 = torch.constant.int 4096 - %5515 = torch.prim.ListConstruct %int4_6911, %int4096_6912 : (!torch.int, !torch.int) -> !torch.list - %5516 = torch.aten.view %5501, %5515 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5517 = torch.aten.mm %5516, %5514 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_6913 = torch.constant.int 4 - %int1_6914 = torch.constant.int 1 - %int1024_6915 = torch.constant.int 1024 - %5518 = torch.prim.ListConstruct %int4_6913, %int1_6914, %int1024_6915 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5519 = torch.aten.view %5517, %5518 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_6916 = torch.constant.int 4 - %int1_6917 = torch.constant.int 1 - %int32_6918 = torch.constant.int 32 - %int128_6919 = torch.constant.int 128 - %5520 = torch.prim.ListConstruct %int4_6916, %int1_6917, %int32_6918, %int128_6919 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5521 = torch.aten.view %5507, %5520 : 
!torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_6920 = torch.constant.int 4 - %int1_6921 = torch.constant.int 1 - %int8_6922 = torch.constant.int 8 - %int128_6923 = torch.constant.int 128 - %5522 = torch.prim.ListConstruct %int4_6920, %int1_6921, %int8_6922, %int128_6923 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5523 = torch.aten.view %5513, %5522 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_6924 = torch.constant.int 4 - %int1_6925 = torch.constant.int 1 - %int8_6926 = torch.constant.int 8 - %int128_6927 = torch.constant.int 128 - %5524 = torch.prim.ListConstruct %int4_6924, %int1_6925, %int8_6926, %int128_6927 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5525 = torch.aten.view %5519, %5524 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_6928 = torch.constant.int 6 - %5526 = torch.prims.convert_element_type %5521, %int6_6928 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %5527 = torch_c.to_builtin_tensor %5526 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %5528 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %5529 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%5527, %5528) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %5530 = torch_c.from_builtin_tensor %5529 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_6929 = torch.constant.int 5 - %5531 = torch.prims.convert_element_type %5530, %int5_6929 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_6930 = torch.constant.int 6 - %5532 = torch.prims.convert_element_type %5523, %int6_6930 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %5533 = torch_c.to_builtin_tensor %5532 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %5534 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %5535 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%5533, %5534) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %5536 = torch_c.from_builtin_tensor %5535 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_6931 = torch.constant.int 5 - %5537 = torch.prims.convert_element_type %5536, %int5_6931 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_6932 = torch.constant.int 32 - %5538 = torch.aten.floor_divide.Scalar %arg2, %int32_6932 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int128_6905 = torch.constant.int 128 + %6448 = torch.prim.ListConstruct %int4_6902, %457, %int8_6903, %int4_6904, %int128_6905 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_6906 = torch.constant.bool false + %6449 = torch.aten.expand %6447, %6448, %false_6906 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6449, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_6907 = torch.constant.int 0 + %6450 = torch.aten.clone %6449, %int0_6907 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6450, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_6908 = torch.constant.int 4 
+ %int32_6909 = torch.constant.int 32 + %int128_6910 = torch.constant.int 128 + %6451 = torch.prim.ListConstruct %int4_6908, %457, %int32_6909, %int128_6910 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6452 = torch.aten._unsafe_view %6450, %6451 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6452, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_6911 = torch.constant.int 1 + %int2_6912 = torch.constant.int 2 + %6453 = torch.aten.transpose.int %6335, %int1_6911, %int2_6912 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_6913 = torch.constant.int 1 + %int2_6914 = torch.constant.int 2 + %6454 = torch.aten.transpose.int %6446, %int1_6913, %int2_6914 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6454, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_6915 = torch.constant.int 1 + %int2_6916 = torch.constant.int 2 + %6455 = torch.aten.transpose.int %6452, %int1_6915, %int2_6916 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6455, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_6917 = torch.constant.float 0.000000e+00 + %false_6918 = torch.constant.bool false + %none_6919 = torch.constant.none + %6456:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6453, %6454, %6455, %float0.000000e00_6917, %false_6918, %470, %none_6919) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_6920 = torch.constant.int 1 + %int2_6921 = torch.constant.int 2 + %6457 = torch.aten.transpose.int %6456#0, %int1_6920, %int2_6921 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_6922 = torch.constant.int 4 + %int1_6923 = torch.constant.int 1 + %int4096_6924 = torch.constant.int 4096 + %6458 = torch.prim.ListConstruct %int4_6922, %int1_6923, %int4096_6924 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6459 = torch.aten.view %6457, %6458 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_6925 = torch.constant.int -2 + %int-1_6926 = torch.constant.int -1 + %6460 = torch.aten.transpose.int %389, %int-2_6925, %int-1_6926 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6927 = torch.constant.int 5 + %6461 = torch.prims.convert_element_type %6460, %int5_6927 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_6928 = torch.constant.int 4 + %int4096_6929 = torch.constant.int 4096 + %6462 = torch.prim.ListConstruct %int4_6928, %int4096_6929 : (!torch.int, !torch.int) -> !torch.list + %6463 = torch.aten.view %6459, %6462 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6464 = torch.aten.mm %6463, %6461 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_6930 = torch.constant.int 4 + %int1_6931 = torch.constant.int 1 + %int4096_6932 = torch.constant.int 4096 + %6465 = torch.prim.ListConstruct 
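// NOTE: grouped-query attention expansion and the attention call itself. The 8 KV
// heads are unsqueezed to [4,seq,8,1,128], broadcast-expanded to [4,seq,8,4,128],
// and _unsafe_view'ed to [4,seq,32,128], so each KV head serves 32/8 = 4 query
// heads. After moving heads ahead of seq, %6456 invokes
// _scaled_dot_product_flash_attention_for_cpu with dropout 0.0, is_causal = false,
// %470 as the [4,1,1,?] additive mask, and a defaulted (none) scale; the
// [4,32,1,128] result is transposed back, viewed as [4,1,4096], and fed through
// the output projection weight %389.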
%int4_6930, %int1_6931, %int4096_6932 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6466 = torch.aten.view %6464, %6465 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_6933 = torch.constant.int 1 - %5539 = torch.aten.unsqueeze %5538, %int1_6933 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6934 = torch.constant.int 1 - %false_6935 = torch.constant.bool false - %5540 = torch.aten.gather %arg3, %int1_6934, %5539, %false_6935 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_6936 = torch.constant.int 32 - %5541 = torch.aten.remainder.Scalar %arg2, %int32_6936 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_6937 = torch.constant.int 1 - %5542 = torch.aten.unsqueeze %5541, %int1_6937 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %6467 = torch.aten.add.Tensor %6288, %6466, %int1_6933 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_6934 = torch.constant.int 6 + %6468 = torch.prims.convert_element_type %6467, %int6_6934 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_6935 = torch.constant.int 2 + %6469 = torch.aten.pow.Tensor_Scalar %6468, %int2_6935 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_6936 = torch.constant.int -1 + %6470 = torch.prim.ListConstruct %int-1_6936 : (!torch.int) -> !torch.list + %true_6937 = torch.constant.bool true %none_6938 = torch.constant.none - %5543 = torch.aten.clone %280, %none_6938 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_6939 = torch.constant.int 0 - %5544 = torch.aten.unsqueeze %5543, %int0_6939 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_6940 = torch.constant.int 4 - %int1_6941 = torch.constant.int 1 - %5545 = torch.prim.ListConstruct %int4_6940, %int1_6941 : (!torch.int, !torch.int) -> !torch.list - %int1_6942 = torch.constant.int 1 - %int1_6943 = torch.constant.int 1 - %5546 = torch.prim.ListConstruct %int1_6942, %int1_6943 : (!torch.int, !torch.int) -> !torch.list - %int4_6944 = torch.constant.int 4 - %int0_6945 = torch.constant.int 0 - %cpu_6946 = torch.constant.device "cpu" - %false_6947 = torch.constant.bool false - %5547 = torch.aten.empty_strided %5545, %5546, %int4_6944, %int0_6945, %cpu_6946, %false_6947 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int25 = torch.constant.int 25 - %5548 = torch.aten.fill.Scalar %5547, %int25 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %6471 = torch.aten.mean.dim %6469, %6470, %true_6937, %none_6938 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_6939 = torch.constant.float 9.9999997473787516E-6 + %int1_6940 = torch.constant.int 1 + %6472 = torch.aten.add.Scalar %6471, %float9.999990e-06_6939, %int1_6940 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %6473 = torch.aten.rsqrt %6472 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %6474 = torch.aten.mul.Tensor %6468, %6473 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_6941 = torch.constant.int 5 + %6475 = torch.prims.convert_element_type %6474, %int5_6941 : 
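// NOTE: decomposed RMSNorm on the residual stream (%6467 = attention output +
// residual): upcast to f32 (prims dtype code 6), compute
// y = x * rsqrt(mean(x^2, dim=-1, keepdim=true) + eps) with
// eps = 9.9999997473787516e-6 (the f32 rounding of 1e-5), downcast to f16
// (dtype code 5), then scale elementwise by the f32 weight vector %390, which
// feeds the FFN that follows.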
!torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %6476 = torch.aten.mul.Tensor %390, %6475 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_6942 = torch.constant.int 5 + %6477 = torch.prims.convert_element_type %6476, %int5_6942 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_6943 = torch.constant.int -2 + %int-1_6944 = torch.constant.int -1 + %6478 = torch.aten.transpose.int %391, %int-2_6943, %int-1_6944 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6945 = torch.constant.int 5 + %6479 = torch.prims.convert_element_type %6478, %int5_6945 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_6946 = torch.constant.int 4 + %int4096_6947 = torch.constant.int 4096 + %6480 = torch.prim.ListConstruct %int4_6946, %int4096_6947 : (!torch.int, !torch.int) -> !torch.list + %6481 = torch.aten.view %6477, %6480 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6482 = torch.aten.mm %6481, %6479 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> %int4_6948 = torch.constant.int 4 %int1_6949 = torch.constant.int 1 - %5549 = torch.prim.ListConstruct %int4_6948, %int1_6949 : (!torch.int, !torch.int) -> !torch.list - %5550 = torch.aten.repeat %5544, %5549 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_6950 = torch.constant.int 32 - %5551 = torch.aten.mul.Scalar %5540, %int32_6950 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6951 = torch.constant.int 1 - %5552 = torch.aten.add.Tensor %5551, %5548, %int1_6951 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_6952 = torch.constant.int 2 - %5553 = torch.aten.mul.Scalar %5552, %int2_6952 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6953 = torch.constant.int 1 - %5554 = torch.aten.add.Tensor %5553, %5550, %int1_6953 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_6954 = torch.constant.int 32 - %5555 = torch.aten.mul.Scalar %5554, %int32_6954 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6955 = torch.constant.int 1 - %5556 = torch.aten.add.Tensor %5555, %5542, %int1_6955 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_6956 = torch.constant.int 32 - %int2_6957 = torch.constant.int 2 - %int32_6958 = torch.constant.int 32 - %int8_6959 = torch.constant.int 8 - %int128_6960 = torch.constant.int 128 - %5557 = torch.prim.ListConstruct %437, %int32_6956, %int2_6957, %int32_6958, %int8_6959, %int128_6960 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5558 = torch.aten.view %5394, %5557 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5558, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_6961 = torch.constant.int 32 - %5559 = torch.aten.mul.int %437, %int32_6961 : !torch.int, !torch.int -> !torch.int - %int2_6962 = torch.constant.int 2 - %5560 = torch.aten.mul.int %5559, %int2_6962 : !torch.int, !torch.int -> !torch.int - %int32_6963 = torch.constant.int 32 - %5561 = torch.aten.mul.int %5560, %int32_6963 : 
!torch.int, !torch.int -> !torch.int - %int8_6964 = torch.constant.int 8 - %int128_6965 = torch.constant.int 128 - %5562 = torch.prim.ListConstruct %5561, %int8_6964, %int128_6965 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5563 = torch.aten.view %5558, %5562 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5563, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %5564 = torch.prim.ListConstruct %5556 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_6966 = torch.constant.bool false - %5565 = torch.aten.index_put %5563, %5564, %5537, %false_6966 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5565, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_6967 = torch.constant.int 32 - %int2_6968 = torch.constant.int 2 - %int32_6969 = torch.constant.int 32 - %int8_6970 = torch.constant.int 8 - %int128_6971 = torch.constant.int 128 - %5566 = torch.prim.ListConstruct %437, %int32_6967, %int2_6968, %int32_6969, %int8_6970, %int128_6971 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5567 = torch.aten.view %5565, %5566 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5567, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_6972 = torch.constant.int 2097152 - %5568 = torch.prim.ListConstruct %437, %int2097152_6972 : (!torch.int, !torch.int) -> !torch.list - %5569 = torch.aten.view %5567, %5568 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5569, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_6973 = torch.constant.int 32 - %int2_6974 = torch.constant.int 2 - %int32_6975 = torch.constant.int 32 - %int8_6976 = torch.constant.int 8 - %int128_6977 = torch.constant.int 128 - %5570 = torch.prim.ListConstruct %437, %int32_6973, %int2_6974, %int32_6975, %int8_6976, %int128_6977 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5571 = torch.aten.view %5569, %5570 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5571, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_6978 = torch.constant.int 8 - %int128_6979 = torch.constant.int 128 - %5572 = torch.prim.ListConstruct %5561, %int8_6978, %int128_6979 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5573 = torch.aten.view %5571, %5572 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5573, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_6980 = torch.constant.int 32 - %5574 = torch.aten.floor_divide.Scalar %arg2, %int32_6980 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_6981 = torch.constant.int 1 - %5575 = torch.aten.unsqueeze %5574, %int1_6981 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_6982 = torch.constant.int 1 - %false_6983 = torch.constant.bool false - %5576 = torch.aten.gather %arg3, %int1_6982, %5575, %false_6983 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, 
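// NOTE: the `-` run here is the replaced write path: the cache was viewed as
// [s0*2048, 8, 128] (32 blocks x 2 partitions x 32 positions folded into one
// axis), a single [4,1] flat index was assembled with the analogous mul/add
// nesting (x32, x2, x32), and one index_put stored the whole [4,1,8,128] update;
// the `+` path above instead indexes per kv head.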
!torch.bool -> !torch.vtensor<[4,1],si64> - %int32_6984 = torch.constant.int 32 - %5577 = torch.aten.remainder.Scalar %arg2, %int32_6984 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_6985 = torch.constant.int 1 - %5578 = torch.aten.unsqueeze %5577, %int1_6985 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_6986 = torch.constant.none - %5579 = torch.aten.clone %281, %none_6986 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_6987 = torch.constant.int 0 - %5580 = torch.aten.unsqueeze %5579, %int0_6987 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + %int14336_6950 = torch.constant.int 14336 + %6483 = torch.prim.ListConstruct %int4_6948, %int1_6949, %int14336_6950 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6484 = torch.aten.view %6482, %6483 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %6485 = torch.aten.silu %6484 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_6951 = torch.constant.int -2 + %int-1_6952 = torch.constant.int -1 + %6486 = torch.aten.transpose.int %392, %int-2_6951, %int-1_6952 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_6953 = torch.constant.int 5 + %6487 = torch.prims.convert_element_type %6486, %int5_6953 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_6954 = torch.constant.int 4 + %int4096_6955 = torch.constant.int 4096 + %6488 = torch.prim.ListConstruct %int4_6954, %int4096_6955 : (!torch.int, !torch.int) -> !torch.list + %6489 = torch.aten.view %6477, %6488 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6490 = torch.aten.mm %6489, %6487 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_6956 = torch.constant.int 4 + %int1_6957 = torch.constant.int 1 + %int14336_6958 = torch.constant.int 14336 + %6491 = torch.prim.ListConstruct %int4_6956, %int1_6957, %int14336_6958 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6492 = torch.aten.view %6490, %6491 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %6493 = torch.aten.mul.Tensor %6485, %6492 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_6959 = torch.constant.int -2 + %int-1_6960 = torch.constant.int -1 + %6494 = torch.aten.transpose.int %393, %int-2_6959, %int-1_6960 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_6961 = torch.constant.int 5 + %6495 = torch.prims.convert_element_type %6494, %int5_6961 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_6962 = torch.constant.int 4 + %int14336_6963 = torch.constant.int 14336 + %6496 = torch.prim.ListConstruct %int4_6962, %int14336_6963 : (!torch.int, !torch.int) -> !torch.list + %6497 = torch.aten.view %6493, %6496 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %6498 = torch.aten.mm %6497, %6495 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_6964 = torch.constant.int 4 + %int1_6965 = torch.constant.int 1 + %int4096_6966 = torch.constant.int 4096 + %6499 = torch.prim.ListConstruct %int4_6964, %int1_6965, %int4096_6966 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6500 = torch.aten.view %6498, %6499 : 
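// NOTE: SwiGLU feed-forward in the `+` ops above: h = silu(x @ W_gate^T) *
// (x @ W_up^T), out = h @ W_down^T, i.e. 4096 -> 14336 -> 4096 with weights %391,
// %392, %393. Each transpose.int is followed by a prims.convert_element_type to
// f16 that is a no-op here (the operand is already f16), and the [4,1,*]
// activations are flattened to 2-D for torch.aten.mm and viewed back afterwards.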
!torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_6967 = torch.constant.int 1 + %6501 = torch.aten.add.Tensor %6467, %6500, %int1_6967 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_6968 = torch.constant.int 6 + %6502 = torch.prims.convert_element_type %6501, %int6_6968 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_6969 = torch.constant.int 2 + %6503 = torch.aten.pow.Tensor_Scalar %6502, %int2_6969 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_6970 = torch.constant.int -1 + %6504 = torch.prim.ListConstruct %int-1_6970 : (!torch.int) -> !torch.list + %true_6971 = torch.constant.bool true + %none_6972 = torch.constant.none + %6505 = torch.aten.mean.dim %6503, %6504, %true_6971, %none_6972 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_6973 = torch.constant.float 9.9999997473787516E-6 + %int1_6974 = torch.constant.int 1 + %6506 = torch.aten.add.Scalar %6505, %float9.999990e-06_6973, %int1_6974 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %6507 = torch.aten.rsqrt %6506 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %6508 = torch.aten.mul.Tensor %6502, %6507 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_6975 = torch.constant.int 5 + %6509 = torch.prims.convert_element_type %6508, %int5_6975 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %6510 = torch.aten.mul.Tensor %394, %6509 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_6976 = torch.constant.int 5 + %6511 = torch.prims.convert_element_type %6510, %int5_6976 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_6977 = torch.constant.int -2 + %int-1_6978 = torch.constant.int -1 + %6512 = torch.aten.transpose.int %395, %int-2_6977, %int-1_6978 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_6979 = torch.constant.int 5 + %6513 = torch.prims.convert_element_type %6512, %int5_6979 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_6980 = torch.constant.int 4 + %int4096_6981 = torch.constant.int 4096 + %6514 = torch.prim.ListConstruct %int4_6980, %int4096_6981 : (!torch.int, !torch.int) -> !torch.list + %6515 = torch.aten.view %6511, %6514 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6516 = torch.aten.mm %6515, %6513 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_6982 = torch.constant.int 4 + %int1_6983 = torch.constant.int 1 + %int4096_6984 = torch.constant.int 4096 + %6517 = torch.prim.ListConstruct %int4_6982, %int1_6983, %int4096_6984 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6518 = torch.aten.view %6516, %6517 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_6985 = torch.constant.int -2 + %int-1_6986 = torch.constant.int -1 + %6519 = torch.aten.transpose.int %396, %int-2_6985, %int-1_6986 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6987 = torch.constant.int 5 + %6520 = torch.prims.convert_element_type %6519, %int5_6987 : !torch.vtensor<[4096,1024],f16>, 
!torch.int -> !torch.vtensor<[4096,1024],f16> %int4_6988 = torch.constant.int 4 - %int1_6989 = torch.constant.int 1 - %5581 = torch.prim.ListConstruct %int4_6988, %int1_6989 : (!torch.int, !torch.int) -> !torch.list - %int1_6990 = torch.constant.int 1 + %int4096_6989 = torch.constant.int 4096 + %6521 = torch.prim.ListConstruct %int4_6988, %int4096_6989 : (!torch.int, !torch.int) -> !torch.list + %6522 = torch.aten.view %6511, %6521 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6523 = torch.aten.mm %6522, %6520 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_6990 = torch.constant.int 4 %int1_6991 = torch.constant.int 1 - %5582 = torch.prim.ListConstruct %int1_6990, %int1_6991 : (!torch.int, !torch.int) -> !torch.list - %int4_6992 = torch.constant.int 4 - %int0_6993 = torch.constant.int 0 - %cpu_6994 = torch.constant.device "cpu" - %false_6995 = torch.constant.bool false - %5583 = torch.aten.empty_strided %5581, %5582, %int4_6992, %int0_6993, %cpu_6994, %false_6995 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int25_6996 = torch.constant.int 25 - %5584 = torch.aten.fill.Scalar %5583, %int25_6996 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_6997 = torch.constant.int 4 - %int1_6998 = torch.constant.int 1 - %5585 = torch.prim.ListConstruct %int4_6997, %int1_6998 : (!torch.int, !torch.int) -> !torch.list - %5586 = torch.aten.repeat %5580, %5585 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_6999 = torch.constant.int 32 - %5587 = torch.aten.mul.Scalar %5576, %int32_6999 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7000 = torch.constant.int 1 - %5588 = torch.aten.add.Tensor %5587, %5584, %int1_7000 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_7001 = torch.constant.int 2 - %5589 = torch.aten.mul.Scalar %5588, %int2_7001 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1024_6992 = torch.constant.int 1024 + %6524 = torch.prim.ListConstruct %int4_6990, %int1_6991, %int1024_6992 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6525 = torch.aten.view %6523, %6524 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_6993 = torch.constant.int -2 + %int-1_6994 = torch.constant.int -1 + %6526 = torch.aten.transpose.int %397, %int-2_6993, %int-1_6994 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_6995 = torch.constant.int 5 + %6527 = torch.prims.convert_element_type %6526, %int5_6995 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_6996 = torch.constant.int 4 + %int4096_6997 = torch.constant.int 4096 + %6528 = torch.prim.ListConstruct %int4_6996, %int4096_6997 : (!torch.int, !torch.int) -> !torch.list + %6529 = torch.aten.view %6511, %6528 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6530 = torch.aten.mm %6529, %6527 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_6998 = torch.constant.int 4 + %int1_6999 = torch.constant.int 1 + %int1024_7000 = torch.constant.int 1024 + %6531 = torch.prim.ListConstruct %int4_6998, %int1_6999, %int1024_7000 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6532 = torch.aten.view %6530, %6531 : 
!torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_7001 = torch.constant.int 4 %int1_7002 = torch.constant.int 1 - %5590 = torch.aten.add.Tensor %5589, %5586, %int1_7002 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> %int32_7003 = torch.constant.int 32 - %5591 = torch.aten.mul.Scalar %5590, %int32_7003 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7004 = torch.constant.int 1 - %5592 = torch.aten.add.Tensor %5591, %5578, %int1_7004 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %5593 = torch.prim.ListConstruct %5592 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_7005 = torch.constant.bool false - %5594 = torch.aten.index_put %5573, %5593, %5525, %false_7005 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5594, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_7006 = torch.constant.int 32 - %int2_7007 = torch.constant.int 2 - %int32_7008 = torch.constant.int 32 - %int8_7009 = torch.constant.int 8 - %int128_7010 = torch.constant.int 128 - %5595 = torch.prim.ListConstruct %437, %int32_7006, %int2_7007, %int32_7008, %int8_7009, %int128_7010 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5596 = torch.aten.view %5594, %5595 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5596, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_7011 = torch.constant.int 2097152 - %5597 = torch.prim.ListConstruct %437, %int2097152_7011 : (!torch.int, !torch.int) -> !torch.list - %5598 = torch.aten.view %5596, %5597 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5598, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_7012 = torch.constant.int 4 - %5599 = torch.prim.ListConstruct %int4_7012, %358 : (!torch.int, !torch.int) -> !torch.list + %int128_7004 = torch.constant.int 128 + %6533 = torch.prim.ListConstruct %int4_7001, %int1_7002, %int32_7003, %int128_7004 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6534 = torch.aten.view %6518, %6533 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_7005 = torch.constant.int 4 + %int1_7006 = torch.constant.int 1 + %int8_7007 = torch.constant.int 8 + %int128_7008 = torch.constant.int 128 + %6535 = torch.prim.ListConstruct %int4_7005, %int1_7006, %int8_7007, %int128_7008 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6536 = torch.aten.view %6525, %6535 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_7009 = torch.constant.int 4 + %int1_7010 = torch.constant.int 1 + %int8_7011 = torch.constant.int 8 + %int128_7012 = torch.constant.int 128 + %6537 = torch.prim.ListConstruct %int4_7009, %int1_7010, %int8_7011, %int128_7012 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6538 = torch.aten.view %6532, %6537 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> %int1_7013 = torch.constant.int 1 - %5600 = torch.prim.ListConstruct %358, %int1_7013 : (!torch.int, !torch.int) -> !torch.list - %int4_7014 = torch.constant.int 4 
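// NOTE: with the FFN residual folded in, the next attention stage begins: the
// attn-norm scale %394, then Q/K/V projections using %395 [4096,4096] and
// %396/%397 [1024,4096], viewed as [4,1,32,128] queries and [4,1,8,128]
// keys/values -- by the weight shapes, the same 32 query-head / 8 kv-head split
// that the 4x expansion above accounts for.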
- %int0_7015 = torch.constant.int 0 - %cpu_7016 = torch.constant.device "cpu" - %false_7017 = torch.constant.bool false - %5601 = torch.aten.empty_strided %5599, %5600, %int4_7014, %int0_7015, %cpu_7016, %false_7017 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5601, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int25_7018 = torch.constant.int 25 - %5602 = torch.aten.fill.Scalar %5601, %int25_7018 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5602, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_7019 = torch.constant.int 32 - %5603 = torch.aten.mul.Scalar %arg3, %int32_7019 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5603, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_7020 = torch.constant.int 1 - %5604 = torch.aten.add.Tensor %5603, %5602, %int1_7020 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5604, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_7021 = torch.constant.int 4 - %5605 = torch.aten.mul.int %int4_7021, %358 : !torch.int, !torch.int -> !torch.int - %5606 = torch.prim.ListConstruct %5605 : (!torch.int) -> !torch.list - %5607 = torch.aten.view %5604, %5606 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %5607, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_7022 = torch.constant.int 32 - %int2_7023 = torch.constant.int 2 - %int32_7024 = torch.constant.int 32 - %int8_7025 = torch.constant.int 8 - %int128_7026 = torch.constant.int 128 - %5608 = torch.prim.ListConstruct %437, %int32_7022, %int2_7023, %int32_7024, %int8_7025, %int128_7026 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5609 = torch.aten.view %5598, %5608 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5609, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7027 = torch.constant.int 32 - %5610 = torch.aten.mul.int %437, %int32_7027 : !torch.int, !torch.int -> !torch.int + %int2_7014 = torch.constant.int 2 + %6539 = torch.aten.transpose.int %6534, %int1_7013, %int2_7014 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %6540 = torch.aten.mul.Tensor %6539, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_7015 = torch.constant.int 3 + %int0_7016 = torch.constant.int 0 + %int64_7017 = torch.constant.int 64 + %int1_7018 = torch.constant.int 1 + %6541 = torch.aten.slice.Tensor %6539, %int3_7015, %int0_7016, %int64_7017, %int1_7018 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_7019 = torch.constant.int 3 + %int64_7020 = torch.constant.int 64 + %int9223372036854775807_7021 = torch.constant.int 9223372036854775807 + %int1_7022 = torch.constant.int 1 + %6542 = torch.aten.slice.Tensor %6539, %int3_7019, %int64_7020, %int9223372036854775807_7021, %int1_7022 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %6543 = torch.aten.neg %6542 : 
!torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %6544 = torch.prim.ListConstruct %6543, %6541 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_7023 = torch.constant.int -1 + %6545 = torch.aten.cat %6544, %int-1_7023 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %6546 = torch.aten.mul.Tensor %6545, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_7024 = torch.constant.int 1 + %6547 = torch.aten.add.Tensor %6540, %6546, %int1_7024 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_7025 = torch.constant.int 1 + %int2_7026 = torch.constant.int 2 + %6548 = torch.aten.transpose.int %6547, %int1_7025, %int2_7026 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_7027 = torch.constant.int 1 %int2_7028 = torch.constant.int 2 - %int32_7029 = torch.constant.int 32 - %int8_7030 = torch.constant.int 8 - %int128_7031 = torch.constant.int 128 - %5611 = torch.prim.ListConstruct %5610, %int2_7028, %int32_7029, %int8_7030, %int128_7031 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5612 = torch.aten.view %5609, %5611 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %5612, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_7032 = torch.constant.int 0 - %5613 = torch.aten.index_select %5612, %int0_7032, %5607 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %5613, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_7033 = torch.constant.int 4 - %int2_7034 = torch.constant.int 2 - %int32_7035 = torch.constant.int 32 - %int8_7036 = torch.constant.int 8 - %int128_7037 = torch.constant.int 128 - %5614 = torch.prim.ListConstruct %int4_7033, %358, %int2_7034, %int32_7035, %int8_7036, %int128_7037 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5615 = torch.aten.view %5613, %5614 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5615, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_7038 = torch.constant.int 0 - %int0_7039 = torch.constant.int 0 - %int9223372036854775807_7040 = torch.constant.int 9223372036854775807 - %int1_7041 = torch.constant.int 1 - %5616 = torch.aten.slice.Tensor %5615, %int0_7038, %int0_7039, %int9223372036854775807_7040, %int1_7041 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5616, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> + %6549 = torch.aten.transpose.int %6536, %int1_7027, %int2_7028 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %6550 = torch.aten.mul.Tensor %6549, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_7029 = torch.constant.int 3 + %int0_7030 = torch.constant.int 0 + %int64_7031 = torch.constant.int 64 + %int1_7032 = torch.constant.int 1 + %6551 = torch.aten.slice.Tensor %6549, 
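// NOTE: rotary position embedding, decomposed in f16: with the cached tables %527
// and %531 (cos and sin, respectively, in the usual formulation; both
// [4,1,1,128]), q' = q*%527 + rotate_half(q)*%531, where rotate_half(q) =
// cat(-q[..., 64:], q[..., :64], dim=-1). The same rotation is applied next to the
// 8-head K tensor. (The `-` side of this hunk instead upcast to f32 and called the
// fused @sharktank_rotary_embedding_* kernels.)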
%int3_7029, %int0_7030, %int64_7031, %int1_7032 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_7033 = torch.constant.int 3 + %int64_7034 = torch.constant.int 64 + %int9223372036854775807_7035 = torch.constant.int 9223372036854775807 + %int1_7036 = torch.constant.int 1 + %6552 = torch.aten.slice.Tensor %6549, %int3_7033, %int64_7034, %int9223372036854775807_7035, %int1_7036 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %6553 = torch.aten.neg %6552 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %6554 = torch.prim.ListConstruct %6553, %6551 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_7037 = torch.constant.int -1 + %6555 = torch.aten.cat %6554, %int-1_7037 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %6556 = torch.aten.mul.Tensor %6555, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_7038 = torch.constant.int 1 + %6557 = torch.aten.add.Tensor %6550, %6556, %int1_7038 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_7039 = torch.constant.int 1 + %int2_7040 = torch.constant.int 2 + %6558 = torch.aten.transpose.int %6557, %int1_7039, %int2_7040 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_7041 = torch.constant.int 32 + %6559 = torch.aten.floor_divide.Scalar %arg2, %int32_7041 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> %int1_7042 = torch.constant.int 1 - %int0_7043 = torch.constant.int 0 - %int9223372036854775807_7044 = torch.constant.int 9223372036854775807 - %int1_7045 = torch.constant.int 1 - %5617 = torch.aten.slice.Tensor %5616, %int1_7042, %int0_7043, %int9223372036854775807_7044, %int1_7045 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5617, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_7046 = torch.constant.int 2 - %int0_7047 = torch.constant.int 0 - %5618 = torch.aten.select.int %5617, %int2_7046, %int0_7047 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5618, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %6560 = torch.aten.unsqueeze %6559, %int1_7042 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_7043 = torch.constant.int 1 + %false_7044 = torch.constant.bool false + %6561 = torch.aten.gather %arg3, %int1_7043, %6560, %false_7044 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_7045 = torch.constant.int 4 + %int1_7046 = torch.constant.int 1 + %int1_7047 = torch.constant.int 1 + %6562 = torch.prim.ListConstruct %int4_7045, %int1_7046, %int1_7047 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6563 = torch.aten.view %6561, %6562 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> %int32_7048 = torch.constant.int 32 - %5619 = torch.aten.mul.int %358, %int32_7048 : !torch.int, !torch.int -> !torch.int - %int2_7049 = torch.constant.int 2 - %int0_7050 = torch.constant.int 0 + %6564 = torch.aten.remainder.Scalar %arg2, %int32_7048 : 
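// NOTE: cache-slot addressing for this single decode position, visible above and
// continuing below: with a block stride of 32, %arg2 // 32 picks the logical page
// (then gathered through the %arg3 page table) and %arg2 % 32 the slot within it;
// the arange(8) and the thrice-detached clone of %398 on the following lines
// enumerate the kv heads and, apparently, this block's layer selector, mirroring
// the %386..%388 constants used for the gather earlier.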
!torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_7049 = torch.constant.int 4 + %int1_7050 = torch.constant.int 1 %int1_7051 = torch.constant.int 1 - %5620 = torch.aten.slice.Tensor %5618, %int2_7049, %int0_7050, %5619, %int1_7051 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5620, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_7052 = torch.constant.int 0 - %5621 = torch.aten.clone %5620, %int0_7052 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5621, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_7053 = torch.constant.int 1 - %5622 = torch.aten.size.int %5617, %int1_7053 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_7054 = torch.constant.int 32 - %5623 = torch.aten.mul.int %5622, %int32_7054 : !torch.int, !torch.int -> !torch.int - %int4_7055 = torch.constant.int 4 - %int8_7056 = torch.constant.int 8 - %int128_7057 = torch.constant.int 128 - %5624 = torch.prim.ListConstruct %int4_7055, %5623, %int8_7056, %int128_7057 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5625 = torch.aten._unsafe_view %5621, %5624 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5625, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_7058 = torch.constant.int 0 - %int0_7059 = torch.constant.int 0 - %int9223372036854775807_7060 = torch.constant.int 9223372036854775807 + %6565 = torch.prim.ListConstruct %int4_7049, %int1_7050, %int1_7051 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6566 = torch.aten.view %6564, %6565 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_7052 = torch.constant.int 8 + %none_7053 = torch.constant.none + %none_7054 = torch.constant.none + %cpu_7055 = torch.constant.device "cpu" + %false_7056 = torch.constant.bool false + %6567 = torch.aten.arange %int8_7052, %none_7053, %none_7054, %cpu_7055, %false_7056 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_7057 = torch.constant.int 1 + %int1_7058 = torch.constant.int 1 + %int8_7059 = torch.constant.int 8 + %6568 = torch.prim.ListConstruct %int1_7057, %int1_7058, %int8_7059 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6569 = torch.aten.view %6567, %6568 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_7060 = torch.constant.none + %6570 = torch.aten.clone %398, %none_7060 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6571 = torch.aten.detach %6570 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6572 = torch.aten.detach %6571 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6573 = torch.aten.detach %6572 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %int1_7061 = torch.constant.int 1 - %5626 = torch.aten.slice.Tensor %5625, %int0_7058, %int0_7059, %int9223372036854775807_7060, %int1_7061 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5626, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_7062 = torch.constant.int 0 - %int0_7063 = torch.constant.int 0 - %int9223372036854775807_7064 = 
torch.constant.int 9223372036854775807 + %int1_7062 = torch.constant.int 1 + %int1_7063 = torch.constant.int 1 + %6574 = torch.prim.ListConstruct %int1_7061, %int1_7062, %int1_7063 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6575 = torch.aten.view %6573, %6574 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_7064 = torch.constant.int 32 + %6576 = torch.aten.mul.Scalar %6563, %int32_7064 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int28 = torch.constant.int 28 %int1_7065 = torch.constant.int 1 - %5627 = torch.aten.slice.Tensor %5615, %int0_7062, %int0_7063, %int9223372036854775807_7064, %int1_7065 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5627, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_7066 = torch.constant.int 1 - %int0_7067 = torch.constant.int 0 - %int9223372036854775807_7068 = torch.constant.int 9223372036854775807 + %6577 = torch.aten.add.Scalar %6576, %int28, %int1_7065 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_7066 = torch.constant.int 2 + %6578 = torch.aten.mul.Scalar %6577, %int2_7066 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_7067 = torch.constant.int 1 + %6579 = torch.aten.add.Tensor %6578, %6575, %int1_7067 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_7068 = torch.constant.int 8 + %6580 = torch.aten.mul.Scalar %6579, %int8_7068 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_7069 = torch.constant.int 1 - %5628 = torch.aten.slice.Tensor %5627, %int1_7066, %int0_7067, %int9223372036854775807_7068, %int1_7069 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5628, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_7070 = torch.constant.int 2 + %6581 = torch.aten.add.Tensor %6580, %6569, %int1_7069 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_7070 = torch.constant.int 32 + %6582 = torch.aten.mul.Scalar %6581, %int32_7070 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_7071 = torch.constant.int 1 - %5629 = torch.aten.select.int %5628, %int2_7070, %int1_7071 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5629, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_7072 = torch.constant.int 2 - %int0_7073 = torch.constant.int 0 - %int1_7074 = torch.constant.int 1 - %5630 = torch.aten.slice.Tensor %5629, %int2_7072, %int0_7073, %5619, %int1_7074 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5630, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_7075 = torch.constant.int 0 - %5631 = torch.aten.clone %5630, %int0_7075 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5631, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - 
%int1_7076 = torch.constant.int 1 - %5632 = torch.aten.size.int %5628, %int1_7076 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_7077 = torch.constant.int 32 - %5633 = torch.aten.mul.int %5632, %int32_7077 : !torch.int, !torch.int -> !torch.int - %int4_7078 = torch.constant.int 4 - %int8_7079 = torch.constant.int 8 - %int128_7080 = torch.constant.int 128 - %5634 = torch.prim.ListConstruct %int4_7078, %5633, %int8_7079, %int128_7080 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5635 = torch.aten._unsafe_view %5631, %5634 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5635, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_7081 = torch.constant.int 0 - %int0_7082 = torch.constant.int 0 - %int9223372036854775807_7083 = torch.constant.int 9223372036854775807 - %int1_7084 = torch.constant.int 1 - %5636 = torch.aten.slice.Tensor %5635, %int0_7081, %int0_7082, %int9223372036854775807_7083, %int1_7084 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5636, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_7085 = torch.constant.int -2 - %5637 = torch.aten.unsqueeze %5626, %int-2_7085 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5637, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_7086 = torch.constant.int 1 - %5638 = torch.aten.size.int %5625, %int1_7086 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_7087 = torch.constant.int 4 + %6583 = torch.aten.add.Tensor %6582, %6566, %int1_7071 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_7072 = torch.constant.int 5 + %6584 = torch.prims.convert_element_type %6558, %int5_7072 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_7073 = torch.constant.int 32 + %int2_7074 = torch.constant.int 2 + %int8_7075 = torch.constant.int 8 + %int32_7076 = torch.constant.int 32 + %int128_7077 = torch.constant.int 128 + %6585 = torch.prim.ListConstruct %456, %int32_7073, %int2_7074, %int8_7075, %int32_7076, %int128_7077 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6586 = torch.aten.view %6406, %6585 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6586, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_7078 = torch.constant.int 128 + %6587 = torch.prim.ListConstruct %596, %int128_7078 : (!torch.int, !torch.int) -> !torch.list + %6588 = torch.aten.view %6586, %6587 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6588, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %6589 = torch.prim.ListConstruct %6583 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_7079 = torch.constant.bool false + %6590 = torch.aten.index_put %6588, %6589, %6584, %false_7079 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6590, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : 
!torch.vtensor<[?,128],f16> + %int32_7080 = torch.constant.int 32 + %int2_7081 = torch.constant.int 2 + %int8_7082 = torch.constant.int 8 + %int32_7083 = torch.constant.int 32 + %int128_7084 = torch.constant.int 128 + %6591 = torch.prim.ListConstruct %456, %int32_7080, %int2_7081, %int8_7082, %int32_7083, %int128_7084 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6592 = torch.aten.view %6590, %6591 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6592, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7085 = torch.constant.int 2097152 + %6593 = torch.prim.ListConstruct %456, %int2097152_7085 : (!torch.int, !torch.int) -> !torch.list + %6594 = torch.aten.view %6592, %6593 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6594, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_7086 = torch.constant.int 32 + %int2_7087 = torch.constant.int 2 %int8_7088 = torch.constant.int 8 - %int4_7089 = torch.constant.int 4 + %int32_7089 = torch.constant.int 32 %int128_7090 = torch.constant.int 128 - %5639 = torch.prim.ListConstruct %int4_7087, %5638, %int8_7088, %int4_7089, %int128_7090 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7091 = torch.constant.bool false - %5640 = torch.aten.expand %5637, %5639, %false_7091 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5640, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_7092 = torch.constant.int 0 - %5641 = torch.aten.clone %5640, %int0_7092 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5641, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7093 = torch.constant.int 4 - %int32_7094 = torch.constant.int 32 - %int128_7095 = torch.constant.int 128 - %5642 = torch.prim.ListConstruct %int4_7093, %5638, %int32_7094, %int128_7095 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5643 = torch.aten._unsafe_view %5641, %5642 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5643, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_7096 = torch.constant.int -2 - %5644 = torch.aten.unsqueeze %5636, %int-2_7096 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5644, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_7097 = torch.constant.int 1 - %5645 = torch.aten.size.int %5635, %int1_7097 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_7098 = torch.constant.int 4 - %int8_7099 = torch.constant.int 8 - %int4_7100 = torch.constant.int 4 - %int128_7101 = torch.constant.int 128 - %5646 = torch.prim.ListConstruct %int4_7098, %5645, %int8_7099, %int4_7100, %int128_7101 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7102 = torch.constant.bool false - %5647 = torch.aten.expand %5644, %5646, %false_7102 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5647, 
[%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_7103 = torch.constant.int 0 - %5648 = torch.aten.clone %5647, %int0_7103 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5648, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7104 = torch.constant.int 4 - %int32_7105 = torch.constant.int 32 - %int128_7106 = torch.constant.int 128 - %5649 = torch.prim.ListConstruct %int4_7104, %5645, %int32_7105, %int128_7106 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5650 = torch.aten._unsafe_view %5648, %5649 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5650, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_7107 = torch.constant.int 1 + %6595 = torch.prim.ListConstruct %456, %int32_7086, %int2_7087, %int8_7088, %int32_7089, %int128_7090 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6596 = torch.aten.view %6594, %6595 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6596, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_7091 = torch.constant.int 128 + %6597 = torch.prim.ListConstruct %596, %int128_7091 : (!torch.int, !torch.int) -> !torch.list + %6598 = torch.aten.view %6596, %6597 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6598, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_7092 = torch.constant.none + %6599 = torch.aten.clone %399, %none_7092 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6600 = torch.aten.detach %6599 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6601 = torch.aten.detach %6600 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6602 = torch.aten.detach %6601 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_7093 = torch.constant.int 1 + %int1_7094 = torch.constant.int 1 + %int1_7095 = torch.constant.int 1 + %6603 = torch.prim.ListConstruct %int1_7093, %int1_7094, %int1_7095 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6604 = torch.aten.view %6602, %6603 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_7096 = torch.constant.int 32 + %6605 = torch.aten.mul.Scalar %6563, %int32_7096 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int28_7097 = torch.constant.int 28 + %int1_7098 = torch.constant.int 1 + %6606 = torch.aten.add.Scalar %6605, %int28_7097, %int1_7098 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_7099 = torch.constant.int 2 + %6607 = torch.aten.mul.Scalar %6606, %int2_7099 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_7100 = torch.constant.int 1 + %6608 = torch.aten.add.Tensor %6607, %6604, %int1_7100 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_7101 = torch.constant.int 8 + %6609 = torch.aten.mul.Scalar %6608, %int8_7101 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_7102 = torch.constant.int 1 + %6610 = torch.aten.add.Tensor %6609, %6569, %int1_7102 : !torch.vtensor<[4,1,1],si64>, 
!torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_7103 = torch.constant.int 32 + %6611 = torch.aten.mul.Scalar %6610, %int32_7103 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_7104 = torch.constant.int 1 + %6612 = torch.aten.add.Tensor %6611, %6566, %int1_7104 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_7105 = torch.constant.int 5 + %6613 = torch.prims.convert_element_type %6538, %int5_7105 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %6614 = torch.prim.ListConstruct %6612 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_7106 = torch.constant.bool false + %6615 = torch.aten.index_put %6598, %6614, %6613, %false_7106 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6615, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_7107 = torch.constant.int 32 %int2_7108 = torch.constant.int 2 - %5651 = torch.aten.transpose.int %5531, %int1_7107, %int2_7108 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_7109 = torch.constant.int 1 - %int2_7110 = torch.constant.int 2 - %5652 = torch.aten.transpose.int %5643, %int1_7109, %int2_7110 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5652, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_7111 = torch.constant.int 1 - %int2_7112 = torch.constant.int 2 - %5653 = torch.aten.transpose.int %5650, %int1_7111, %int2_7112 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5653, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_7113 = torch.constant.float 0.000000e+00 - %false_7114 = torch.constant.bool false + %int8_7109 = torch.constant.int 8 + %int32_7110 = torch.constant.int 32 + %int128_7111 = torch.constant.int 128 + %6616 = torch.prim.ListConstruct %456, %int32_7107, %int2_7108, %int8_7109, %int32_7110, %int128_7111 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6617 = torch.aten.view %6615, %6616 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6617, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7112 = torch.constant.int 2097152 + %6618 = torch.prim.ListConstruct %456, %int2097152_7112 : (!torch.int, !torch.int) -> !torch.list + %6619 = torch.aten.view %6617, %6618 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6619, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_7113 = torch.constant.none + %6620 = torch.aten.clone %400, %none_7113 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6621 = torch.aten.detach %6620 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6622 = torch.aten.detach %6621 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6623 = torch.aten.detach %6622 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_7114 = torch.constant.none + %6624 = torch.aten.clone %401, %none_7114 : !torch.vtensor<[],si64>, 
!torch.none -> !torch.vtensor<[],si64> + %6625 = torch.aten.detach %6624 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6626 = torch.aten.detach %6625 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6627 = torch.aten.detach %6626 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %none_7115 = torch.constant.none - %5654:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5651, %5652, %5653, %float0.000000e00_7113, %false_7114, %368, %none_7115) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_7116 = torch.constant.int 1 + %6628 = torch.aten.clone %402, %none_7115 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6629 = torch.aten.detach %6628 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6630 = torch.aten.detach %6629 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6631 = torch.aten.detach %6630 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_7116 = torch.constant.int 32 %int2_7117 = torch.constant.int 2 - %5655 = torch.aten.transpose.int %5654#0, %int1_7116, %int2_7117 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_7118 = torch.constant.int 4 - %int1_7119 = torch.constant.int 1 - %int4096_7120 = torch.constant.int 4096 - %5656 = torch.prim.ListConstruct %int4_7118, %int1_7119, %int4096_7120 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5657 = torch.aten.view %5655, %5656 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_7121 = torch.constant.int -2 - %int-1_7122 = torch.constant.int -1 - %5658 = torch.aten.transpose.int %282, %int-2_7121, %int-1_7122 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7123 = torch.constant.int 4 - %int4096_7124 = torch.constant.int 4096 - %5659 = torch.prim.ListConstruct %int4_7123, %int4096_7124 : (!torch.int, !torch.int) -> !torch.list - %5660 = torch.aten.view %5657, %5659 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5661 = torch.aten.mm %5660, %5658 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_7125 = torch.constant.int 4 - %int1_7126 = torch.constant.int 1 - %int4096_7127 = torch.constant.int 4096 - %5662 = torch.prim.ListConstruct %int4_7125, %int1_7126, %int4096_7127 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5663 = torch.aten.view %5661, %5662 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_7128 = torch.constant.int 1 - %5664 = torch.aten.add.Tensor %5491, %5663, %int1_7128 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_7129 = torch.constant.int 6 - %5665 = torch.prims.convert_element_type %5664, %int6_7129 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_7130 = torch.constant.int 2 - %5666 = torch.aten.pow.Tensor_Scalar %5665, %int2_7130 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_7131 = torch.constant.int -1 - %5667 = torch.prim.ListConstruct %int-1_7131 : (!torch.int) -> !torch.list - %true_7132 = torch.constant.bool true - %none_7133 = torch.constant.none - %5668 = torch.aten.mean.dim %5666, %5667, 
%true_7132, %none_7133 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_7134 = torch.constant.float 9.9999997473787516E-6 - %int1_7135 = torch.constant.int 1 - %5669 = torch.aten.add.Scalar %5668, %float9.999990e-06_7134, %int1_7135 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %5670 = torch.aten.rsqrt %5669 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %5671 = torch.aten.mul.Tensor %5665, %5670 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_7136 = torch.constant.int 5 - %5672 = torch.prims.convert_element_type %5671, %int5_7136 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %5673 = torch.aten.mul.Tensor %283, %5672 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_7137 = torch.constant.int 5 - %5674 = torch.prims.convert_element_type %5673, %int5_7137 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_7138 = torch.constant.int -2 - %int-1_7139 = torch.constant.int -1 - %5675 = torch.aten.transpose.int %284, %int-2_7138, %int-1_7139 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int8_7118 = torch.constant.int 8 + %int32_7119 = torch.constant.int 32 + %int128_7120 = torch.constant.int 128 + %6632 = torch.prim.ListConstruct %456, %int32_7116, %int2_7117, %int8_7118, %int32_7119, %int128_7120 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6633 = torch.aten.view %6619, %6632 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6633, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %6634 = torch_c.to_builtin_tensor %6633 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %6635 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_7121 = tensor.cast %6635 : tensor<4x?xi64> to tensor + %6636 = torch_c.to_builtin_tensor %6623 : !torch.vtensor<[],si64> -> tensor + %6637 = torch_c.to_builtin_tensor %6627 : !torch.vtensor<[],si64> -> tensor + %6638 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%6634, %cast_7121, %6636, %6637) : (tensor, tensor, tensor, tensor) -> tensor + %cast_7122 = tensor.cast %6638 : tensor to tensor<4x?x8x32x128xf16> + %6639 = torch_c.from_builtin_tensor %cast_7122 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %6639, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %6640 = torch_c.to_builtin_tensor %6633 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %6641 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_7123 = tensor.cast %6641 : tensor<4x?xi64> to tensor + %6642 = torch_c.to_builtin_tensor %6623 : !torch.vtensor<[],si64> -> tensor + %6643 = torch_c.to_builtin_tensor %6631 : !torch.vtensor<[],si64> -> tensor + %6644 = util.call 
@paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%6640, %cast_7123, %6642, %6643) : (tensor, tensor, tensor, tensor) -> tensor + %cast_7124 = tensor.cast %6644 : tensor to tensor<4x?x8x32x128xf16> + %6645 = torch_c.from_builtin_tensor %cast_7124 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %6645, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_7125 = torch.constant.int 2 + %int3_7126 = torch.constant.int 3 + %6646 = torch.aten.transpose.int %6639, %int2_7125, %int3_7126 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6646, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_7127 = torch.constant.int 0 + %6647 = torch.aten.clone %6646, %int0_7127 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6647, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_7128 = torch.constant.int 4 + %int8_7129 = torch.constant.int 8 + %int128_7130 = torch.constant.int 128 + %6648 = torch.prim.ListConstruct %int4_7128, %457, %int8_7129, %int128_7130 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6649 = torch.aten._unsafe_view %6647, %6648 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6649, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_7131 = torch.constant.int 2 + %int3_7132 = torch.constant.int 3 + %6650 = torch.aten.transpose.int %6645, %int2_7131, %int3_7132 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6650, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_7133 = torch.constant.int 0 + %6651 = torch.aten.clone %6650, %int0_7133 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6651, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_7134 = torch.constant.int 4 + %int8_7135 = torch.constant.int 8 + %int128_7136 = torch.constant.int 128 + %6652 = torch.prim.ListConstruct %int4_7134, %457, %int8_7135, %int128_7136 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6653 = torch.aten._unsafe_view %6651, %6652 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6653, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_7137 = torch.constant.int -2 + %6654 = torch.aten.unsqueeze %6649, %int-2_7137 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %6654, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_7138 = torch.constant.int 4 + %int8_7139 = torch.constant.int 8 %int4_7140 = torch.constant.int 4 - %int4096_7141 = torch.constant.int 4096 - %5676 = torch.prim.ListConstruct %int4_7140, %int4096_7141 : (!torch.int, !torch.int) -> !torch.list - %5677 = torch.aten.view %5674, %5676 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> 
!torch.vtensor<[4,4096],f16> - %5678 = torch.aten.mm %5677, %5675 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_7142 = torch.constant.int 4 - %int1_7143 = torch.constant.int 1 - %int14336_7144 = torch.constant.int 14336 - %5679 = torch.prim.ListConstruct %int4_7142, %int1_7143, %int14336_7144 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5680 = torch.aten.view %5678, %5679 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %5681 = torch.aten.silu %5680 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_7145 = torch.constant.int -2 - %int-1_7146 = torch.constant.int -1 - %5682 = torch.aten.transpose.int %285, %int-2_7145, %int-1_7146 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_7147 = torch.constant.int 4 - %int4096_7148 = torch.constant.int 4096 - %5683 = torch.prim.ListConstruct %int4_7147, %int4096_7148 : (!torch.int, !torch.int) -> !torch.list - %5684 = torch.aten.view %5674, %5683 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5685 = torch.aten.mm %5684, %5682 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_7149 = torch.constant.int 4 - %int1_7150 = torch.constant.int 1 - %int14336_7151 = torch.constant.int 14336 - %5686 = torch.prim.ListConstruct %int4_7149, %int1_7150, %int14336_7151 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5687 = torch.aten.view %5685, %5686 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %5688 = torch.aten.mul.Tensor %5681, %5687 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_7152 = torch.constant.int -2 - %int-1_7153 = torch.constant.int -1 - %5689 = torch.aten.transpose.int %286, %int-2_7152, %int-1_7153 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int128_7141 = torch.constant.int 128 + %6655 = torch.prim.ListConstruct %int4_7138, %457, %int8_7139, %int4_7140, %int128_7141 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_7142 = torch.constant.bool false + %6656 = torch.aten.expand %6654, %6655, %false_7142 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6656, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_7143 = torch.constant.int 0 + %6657 = torch.aten.clone %6656, %int0_7143 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6657, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_7144 = torch.constant.int 4 + %int32_7145 = torch.constant.int 32 + %int128_7146 = torch.constant.int 128 + %6658 = torch.prim.ListConstruct %int4_7144, %457, %int32_7145, %int128_7146 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6659 = torch.aten._unsafe_view %6657, %6658 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6659, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_7147 = torch.constant.int -2 + %6660 = torch.aten.unsqueeze %6653, %int-2_7147 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + 
torch.bind_symbolic_shape %6660, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_7148 = torch.constant.int 4 + %int8_7149 = torch.constant.int 8 + %int4_7150 = torch.constant.int 4 + %int128_7151 = torch.constant.int 128 + %6661 = torch.prim.ListConstruct %int4_7148, %457, %int8_7149, %int4_7150, %int128_7151 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_7152 = torch.constant.bool false + %6662 = torch.aten.expand %6660, %6661, %false_7152 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6662, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_7153 = torch.constant.int 0 + %6663 = torch.aten.clone %6662, %int0_7153 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6663, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_7154 = torch.constant.int 4 - %int14336_7155 = torch.constant.int 14336 - %5690 = torch.prim.ListConstruct %int4_7154, %int14336_7155 : (!torch.int, !torch.int) -> !torch.list - %5691 = torch.aten.view %5688, %5690 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %5692 = torch.aten.mm %5691, %5689 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_7156 = torch.constant.int 4 + %int32_7155 = torch.constant.int 32 + %int128_7156 = torch.constant.int 128 + %6664 = torch.prim.ListConstruct %int4_7154, %457, %int32_7155, %int128_7156 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6665 = torch.aten._unsafe_view %6663, %6664 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6665, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> %int1_7157 = torch.constant.int 1 - %int4096_7158 = torch.constant.int 4096 - %5693 = torch.prim.ListConstruct %int4_7156, %int1_7157, %int4096_7158 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5694 = torch.aten.view %5692, %5693 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int2_7158 = torch.constant.int 2 + %6666 = torch.aten.transpose.int %6548, %int1_7157, %int2_7158 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_7159 = torch.constant.int 1 - %5695 = torch.aten.add.Tensor %5664, %5694, %int1_7159 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_7160 = torch.constant.int 6 - %5696 = torch.prims.convert_element_type %5695, %int6_7160 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_7161 = torch.constant.int 2 - %5697 = torch.aten.pow.Tensor_Scalar %5696, %int2_7161 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_7162 = torch.constant.int -1 - %5698 = torch.prim.ListConstruct %int-1_7162 : (!torch.int) -> !torch.list - %true_7163 = torch.constant.bool true - %none_7164 = torch.constant.none - %5699 = torch.aten.mean.dim %5697, %5698, %true_7163, %none_7164 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_7165 = torch.constant.float 9.9999997473787516E-6 + %int2_7160 = torch.constant.int 2 + %6667 = 
torch.aten.transpose.int %6659, %int1_7159, %int2_7160 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6667, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_7161 = torch.constant.int 1 + %int2_7162 = torch.constant.int 2 + %6668 = torch.aten.transpose.int %6665, %int1_7161, %int2_7162 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6668, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_7163 = torch.constant.float 0.000000e+00 + %false_7164 = torch.constant.bool false + %none_7165 = torch.constant.none + %6669:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6666, %6667, %6668, %float0.000000e00_7163, %false_7164, %470, %none_7165) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) %int1_7166 = torch.constant.int 1 - %5700 = torch.aten.add.Scalar %5699, %float9.999990e-06_7165, %int1_7166 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %5701 = torch.aten.rsqrt %5700 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %5702 = torch.aten.mul.Tensor %5696, %5701 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_7167 = torch.constant.int 5 - %5703 = torch.prims.convert_element_type %5702, %int5_7167 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %5704 = torch.aten.mul.Tensor %287, %5703 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_7168 = torch.constant.int 5 - %5705 = torch.prims.convert_element_type %5704, %int5_7168 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_7169 = torch.constant.int -2 - %int-1_7170 = torch.constant.int -1 - %5706 = torch.aten.transpose.int %288, %int-2_7169, %int-1_7170 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7171 = torch.constant.int 4 - %int4096_7172 = torch.constant.int 4096 - %5707 = torch.prim.ListConstruct %int4_7171, %int4096_7172 : (!torch.int, !torch.int) -> !torch.list - %5708 = torch.aten.view %5705, %5707 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5709 = torch.aten.mm %5708, %5706 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_7173 = torch.constant.int 4 - %int1_7174 = torch.constant.int 1 + %int2_7167 = torch.constant.int 2 + %6670 = torch.aten.transpose.int %6669#0, %int1_7166, %int2_7167 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_7168 = torch.constant.int 4 + %int1_7169 = torch.constant.int 1 + %int4096_7170 = torch.constant.int 4096 + %6671 = torch.prim.ListConstruct %int4_7168, %int1_7169, %int4096_7170 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6672 = torch.aten.view %6670, %6671 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_7171 = torch.constant.int -2 + %int-1_7172 = torch.constant.int -1 + %6673 = torch.aten.transpose.int %403, %int-2_7171, %int-1_7172 : 
!torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_7173 = torch.constant.int 5 + %6674 = torch.prims.convert_element_type %6673, %int5_7173 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_7174 = torch.constant.int 4 %int4096_7175 = torch.constant.int 4096 - %5710 = torch.prim.ListConstruct %int4_7173, %int1_7174, %int4096_7175 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5711 = torch.aten.view %5709, %5710 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_7176 = torch.constant.int -2 - %int-1_7177 = torch.constant.int -1 - %5712 = torch.aten.transpose.int %289, %int-2_7176, %int-1_7177 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_7178 = torch.constant.int 4 - %int4096_7179 = torch.constant.int 4096 - %5713 = torch.prim.ListConstruct %int4_7178, %int4096_7179 : (!torch.int, !torch.int) -> !torch.list - %5714 = torch.aten.view %5705, %5713 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5715 = torch.aten.mm %5714, %5712 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_7180 = torch.constant.int 4 - %int1_7181 = torch.constant.int 1 - %int1024_7182 = torch.constant.int 1024 - %5716 = torch.prim.ListConstruct %int4_7180, %int1_7181, %int1024_7182 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5717 = torch.aten.view %5715, %5716 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_7183 = torch.constant.int -2 - %int-1_7184 = torch.constant.int -1 - %5718 = torch.aten.transpose.int %290, %int-2_7183, %int-1_7184 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_7185 = torch.constant.int 4 - %int4096_7186 = torch.constant.int 4096 - %5719 = torch.prim.ListConstruct %int4_7185, %int4096_7186 : (!torch.int, !torch.int) -> !torch.list - %5720 = torch.aten.view %5705, %5719 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5721 = torch.aten.mm %5720, %5718 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_7187 = torch.constant.int 4 - %int1_7188 = torch.constant.int 1 - %int1024_7189 = torch.constant.int 1024 - %5722 = torch.prim.ListConstruct %int4_7187, %int1_7188, %int1024_7189 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5723 = torch.aten.view %5721, %5722 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_7190 = torch.constant.int 4 - %int1_7191 = torch.constant.int 1 - %int32_7192 = torch.constant.int 32 - %int128_7193 = torch.constant.int 128 - %5724 = torch.prim.ListConstruct %int4_7190, %int1_7191, %int32_7192, %int128_7193 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5725 = torch.aten.view %5711, %5724 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %6675 = torch.prim.ListConstruct %int4_7174, %int4096_7175 : (!torch.int, !torch.int) -> !torch.list + %6676 = torch.aten.view %6672, %6675 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6677 = torch.aten.mm %6676, %6674 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_7176 = torch.constant.int 4 + %int1_7177 = torch.constant.int 1 + %int4096_7178 = torch.constant.int 4096 + %6678 = 
torch.prim.ListConstruct %int4_7176, %int1_7177, %int4096_7178 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6679 = torch.aten.view %6677, %6678 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_7179 = torch.constant.int 1 + %6680 = torch.aten.add.Tensor %6501, %6679, %int1_7179 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_7180 = torch.constant.int 6 + %6681 = torch.prims.convert_element_type %6680, %int6_7180 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_7181 = torch.constant.int 2 + %6682 = torch.aten.pow.Tensor_Scalar %6681, %int2_7181 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_7182 = torch.constant.int -1 + %6683 = torch.prim.ListConstruct %int-1_7182 : (!torch.int) -> !torch.list + %true_7183 = torch.constant.bool true + %none_7184 = torch.constant.none + %6684 = torch.aten.mean.dim %6682, %6683, %true_7183, %none_7184 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_7185 = torch.constant.float 9.9999997473787516E-6 + %int1_7186 = torch.constant.int 1 + %6685 = torch.aten.add.Scalar %6684, %float9.999990e-06_7185, %int1_7186 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %6686 = torch.aten.rsqrt %6685 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %6687 = torch.aten.mul.Tensor %6681, %6686 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_7187 = torch.constant.int 5 + %6688 = torch.prims.convert_element_type %6687, %int5_7187 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %6689 = torch.aten.mul.Tensor %404, %6688 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_7188 = torch.constant.int 5 + %6690 = torch.prims.convert_element_type %6689, %int5_7188 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_7189 = torch.constant.int -2 + %int-1_7190 = torch.constant.int -1 + %6691 = torch.aten.transpose.int %405, %int-2_7189, %int-1_7190 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_7191 = torch.constant.int 5 + %6692 = torch.prims.convert_element_type %6691, %int5_7191 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_7192 = torch.constant.int 4 + %int4096_7193 = torch.constant.int 4096 + %6693 = torch.prim.ListConstruct %int4_7192, %int4096_7193 : (!torch.int, !torch.int) -> !torch.list + %6694 = torch.aten.view %6690, %6693 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6695 = torch.aten.mm %6694, %6692 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> %int4_7194 = torch.constant.int 4 %int1_7195 = torch.constant.int 1 - %int8_7196 = torch.constant.int 8 - %int128_7197 = torch.constant.int 128 - %5726 = torch.prim.ListConstruct %int4_7194, %int1_7195, %int8_7196, %int128_7197 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5727 = torch.aten.view %5717, %5726 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_7198 = torch.constant.int 4 - %int1_7199 = torch.constant.int 1 - %int8_7200 = torch.constant.int 8 - %int128_7201 = torch.constant.int 128 - 
%5728 = torch.prim.ListConstruct %int4_7198, %int1_7199, %int8_7200, %int128_7201 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5729 = torch.aten.view %5723, %5728 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_7202 = torch.constant.int 6 - %5730 = torch.prims.convert_element_type %5725, %int6_7202 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %5731 = torch_c.to_builtin_tensor %5730 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %5732 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %5733 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%5731, %5732) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %5734 = torch_c.from_builtin_tensor %5733 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_7203 = torch.constant.int 5 - %5735 = torch.prims.convert_element_type %5734, %int5_7203 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_7204 = torch.constant.int 6 - %5736 = torch.prims.convert_element_type %5727, %int6_7204 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %5737 = torch_c.to_builtin_tensor %5736 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %5738 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %5739 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%5737, %5738) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %5740 = torch_c.from_builtin_tensor %5739 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_7205 = torch.constant.int 5 - %5741 = torch.prims.convert_element_type %5740, %int5_7205 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_7206 = torch.constant.int 32 - %5742 = torch.aten.floor_divide.Scalar %arg2, %int32_7206 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_7207 = torch.constant.int 1 - %5743 = torch.aten.unsqueeze %5742, %int1_7207 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7208 = torch.constant.int 1 - %false_7209 = torch.constant.bool false - %5744 = torch.aten.gather %arg3, %int1_7208, %5743, %false_7209 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_7210 = torch.constant.int 32 - %5745 = torch.aten.remainder.Scalar %arg2, %int32_7210 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int14336_7196 = torch.constant.int 14336 + %6696 = torch.prim.ListConstruct %int4_7194, %int1_7195, %int14336_7196 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6697 = torch.aten.view %6695, %6696 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %6698 = torch.aten.silu %6697 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_7197 = torch.constant.int -2 + %int-1_7198 = torch.constant.int -1 + %6699 = torch.aten.transpose.int %406, %int-2_7197, %int-1_7198 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_7199 = torch.constant.int 5 + %6700 = torch.prims.convert_element_type %6699, %int5_7199 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_7200 = torch.constant.int 4 + %int4096_7201 = torch.constant.int 4096 + %6701 = torch.prim.ListConstruct 
%int4_7200, %int4096_7201 : (!torch.int, !torch.int) -> !torch.list + %6702 = torch.aten.view %6690, %6701 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6703 = torch.aten.mm %6702, %6700 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_7202 = torch.constant.int 4 + %int1_7203 = torch.constant.int 1 + %int14336_7204 = torch.constant.int 14336 + %6704 = torch.prim.ListConstruct %int4_7202, %int1_7203, %int14336_7204 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6705 = torch.aten.view %6703, %6704 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %6706 = torch.aten.mul.Tensor %6698, %6705 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_7205 = torch.constant.int -2 + %int-1_7206 = torch.constant.int -1 + %6707 = torch.aten.transpose.int %407, %int-2_7205, %int-1_7206 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_7207 = torch.constant.int 5 + %6708 = torch.prims.convert_element_type %6707, %int5_7207 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_7208 = torch.constant.int 4 + %int14336_7209 = torch.constant.int 14336 + %6709 = torch.prim.ListConstruct %int4_7208, %int14336_7209 : (!torch.int, !torch.int) -> !torch.list + %6710 = torch.aten.view %6706, %6709 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %6711 = torch.aten.mm %6710, %6708 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_7210 = torch.constant.int 4 %int1_7211 = torch.constant.int 1 - %5746 = torch.aten.unsqueeze %5745, %int1_7211 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_7212 = torch.constant.none - %5747 = torch.aten.clone %291, %none_7212 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_7213 = torch.constant.int 0 - %5748 = torch.aten.unsqueeze %5747, %int0_7213 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_7214 = torch.constant.int 4 - %int1_7215 = torch.constant.int 1 - %5749 = torch.prim.ListConstruct %int4_7214, %int1_7215 : (!torch.int, !torch.int) -> !torch.list - %int1_7216 = torch.constant.int 1 - %int1_7217 = torch.constant.int 1 - %5750 = torch.prim.ListConstruct %int1_7216, %int1_7217 : (!torch.int, !torch.int) -> !torch.list - %int4_7218 = torch.constant.int 4 - %int0_7219 = torch.constant.int 0 - %cpu_7220 = torch.constant.device "cpu" - %false_7221 = torch.constant.bool false - %5751 = torch.aten.empty_strided %5749, %5750, %int4_7218, %int0_7219, %cpu_7220, %false_7221 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int26 = torch.constant.int 26 - %5752 = torch.aten.fill.Scalar %5751, %int26 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_7222 = torch.constant.int 4 - %int1_7223 = torch.constant.int 1 - %5753 = torch.prim.ListConstruct %int4_7222, %int1_7223 : (!torch.int, !torch.int) -> !torch.list - %5754 = torch.aten.repeat %5748, %5753 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_7224 = torch.constant.int 32 - %5755 = torch.aten.mul.Scalar %5744, %int32_7224 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7225 = torch.constant.int 1 - %5756 = torch.aten.add.Tensor %5755, %5752, 
%int1_7225 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_7226 = torch.constant.int 2 - %5757 = torch.aten.mul.Scalar %5756, %int2_7226 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7227 = torch.constant.int 1 - %5758 = torch.aten.add.Tensor %5757, %5754, %int1_7227 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_7228 = torch.constant.int 32 - %5759 = torch.aten.mul.Scalar %5758, %int32_7228 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int4096_7212 = torch.constant.int 4096 + %6712 = torch.prim.ListConstruct %int4_7210, %int1_7211, %int4096_7212 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6713 = torch.aten.view %6711, %6712 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_7213 = torch.constant.int 1 + %6714 = torch.aten.add.Tensor %6680, %6713, %int1_7213 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_7214 = torch.constant.int 6 + %6715 = torch.prims.convert_element_type %6714, %int6_7214 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_7215 = torch.constant.int 2 + %6716 = torch.aten.pow.Tensor_Scalar %6715, %int2_7215 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_7216 = torch.constant.int -1 + %6717 = torch.prim.ListConstruct %int-1_7216 : (!torch.int) -> !torch.list + %true_7217 = torch.constant.bool true + %none_7218 = torch.constant.none + %6718 = torch.aten.mean.dim %6716, %6717, %true_7217, %none_7218 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_7219 = torch.constant.float 9.9999997473787516E-6 + %int1_7220 = torch.constant.int 1 + %6719 = torch.aten.add.Scalar %6718, %float9.999990e-06_7219, %int1_7220 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %6720 = torch.aten.rsqrt %6719 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %6721 = torch.aten.mul.Tensor %6715, %6720 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_7221 = torch.constant.int 5 + %6722 = torch.prims.convert_element_type %6721, %int5_7221 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %6723 = torch.aten.mul.Tensor %408, %6722 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_7222 = torch.constant.int 5 + %6724 = torch.prims.convert_element_type %6723, %int5_7222 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_7223 = torch.constant.int -2 + %int-1_7224 = torch.constant.int -1 + %6725 = torch.aten.transpose.int %409, %int-2_7223, %int-1_7224 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_7225 = torch.constant.int 5 + %6726 = torch.prims.convert_element_type %6725, %int5_7225 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_7226 = torch.constant.int 4 + %int4096_7227 = torch.constant.int 4096 + %6727 = torch.prim.ListConstruct %int4_7226, %int4096_7227 : (!torch.int, !torch.int) -> !torch.list + %6728 = torch.aten.view %6724, %6727 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6729 = torch.aten.mm %6728, 
%6726 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_7228 = torch.constant.int 4 %int1_7229 = torch.constant.int 1 - %5760 = torch.aten.add.Tensor %5759, %5746, %int1_7229 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_7230 = torch.constant.int 32 - %int2_7231 = torch.constant.int 2 - %int32_7232 = torch.constant.int 32 - %int8_7233 = torch.constant.int 8 - %int128_7234 = torch.constant.int 128 - %5761 = torch.prim.ListConstruct %437, %int32_7230, %int2_7231, %int32_7232, %int8_7233, %int128_7234 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5762 = torch.aten.view %5598, %5761 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5762, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7235 = torch.constant.int 32 - %5763 = torch.aten.mul.int %437, %int32_7235 : !torch.int, !torch.int -> !torch.int - %int2_7236 = torch.constant.int 2 - %5764 = torch.aten.mul.int %5763, %int2_7236 : !torch.int, !torch.int -> !torch.int - %int32_7237 = torch.constant.int 32 - %5765 = torch.aten.mul.int %5764, %int32_7237 : !torch.int, !torch.int -> !torch.int - %int8_7238 = torch.constant.int 8 - %int128_7239 = torch.constant.int 128 - %5766 = torch.prim.ListConstruct %5765, %int8_7238, %int128_7239 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5767 = torch.aten.view %5762, %5766 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5767, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %5768 = torch.prim.ListConstruct %5760 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_7240 = torch.constant.bool false - %5769 = torch.aten.index_put %5767, %5768, %5741, %false_7240 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5769, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_7241 = torch.constant.int 32 - %int2_7242 = torch.constant.int 2 - %int32_7243 = torch.constant.int 32 - %int8_7244 = torch.constant.int 8 - %int128_7245 = torch.constant.int 128 - %5770 = torch.prim.ListConstruct %437, %int32_7241, %int2_7242, %int32_7243, %int8_7244, %int128_7245 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5771 = torch.aten.view %5769, %5770 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5771, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_7246 = torch.constant.int 2097152 - %5772 = torch.prim.ListConstruct %437, %int2097152_7246 : (!torch.int, !torch.int) -> !torch.list - %5773 = torch.aten.view %5771, %5772 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5773, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_7247 = torch.constant.int 32 - %int2_7248 = torch.constant.int 2 + %int4096_7230 = torch.constant.int 4096 + %6730 = torch.prim.ListConstruct %int4_7228, %int1_7229, %int4096_7230 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6731 = torch.aten.view %6729, %6730 : !torch.vtensor<[4,4096],f16>, 
!torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_7231 = torch.constant.int -2 + %int-1_7232 = torch.constant.int -1 + %6732 = torch.aten.transpose.int %410, %int-2_7231, %int-1_7232 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_7233 = torch.constant.int 5 + %6733 = torch.prims.convert_element_type %6732, %int5_7233 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_7234 = torch.constant.int 4 + %int4096_7235 = torch.constant.int 4096 + %6734 = torch.prim.ListConstruct %int4_7234, %int4096_7235 : (!torch.int, !torch.int) -> !torch.list + %6735 = torch.aten.view %6724, %6734 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6736 = torch.aten.mm %6735, %6733 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_7236 = torch.constant.int 4 + %int1_7237 = torch.constant.int 1 + %int1024_7238 = torch.constant.int 1024 + %6737 = torch.prim.ListConstruct %int4_7236, %int1_7237, %int1024_7238 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6738 = torch.aten.view %6736, %6737 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_7239 = torch.constant.int -2 + %int-1_7240 = torch.constant.int -1 + %6739 = torch.aten.transpose.int %411, %int-2_7239, %int-1_7240 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_7241 = torch.constant.int 5 + %6740 = torch.prims.convert_element_type %6739, %int5_7241 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_7242 = torch.constant.int 4 + %int4096_7243 = torch.constant.int 4096 + %6741 = torch.prim.ListConstruct %int4_7242, %int4096_7243 : (!torch.int, !torch.int) -> !torch.list + %6742 = torch.aten.view %6724, %6741 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6743 = torch.aten.mm %6742, %6740 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_7244 = torch.constant.int 4 + %int1_7245 = torch.constant.int 1 + %int1024_7246 = torch.constant.int 1024 + %6744 = torch.prim.ListConstruct %int4_7244, %int1_7245, %int1024_7246 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6745 = torch.aten.view %6743, %6744 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_7247 = torch.constant.int 4 + %int1_7248 = torch.constant.int 1 %int32_7249 = torch.constant.int 32 - %int8_7250 = torch.constant.int 8 - %int128_7251 = torch.constant.int 128 - %5774 = torch.prim.ListConstruct %437, %int32_7247, %int2_7248, %int32_7249, %int8_7250, %int128_7251 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5775 = torch.aten.view %5773, %5774 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5775, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_7252 = torch.constant.int 8 - %int128_7253 = torch.constant.int 128 - %5776 = torch.prim.ListConstruct %5765, %int8_7252, %int128_7253 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5777 = torch.aten.view %5775, %5776 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5777, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_7254 = 
torch.constant.int 32 - %5778 = torch.aten.floor_divide.Scalar %arg2, %int32_7254 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_7255 = torch.constant.int 1 - %5779 = torch.aten.unsqueeze %5778, %int1_7255 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int128_7250 = torch.constant.int 128 + %6746 = torch.prim.ListConstruct %int4_7247, %int1_7248, %int32_7249, %int128_7250 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6747 = torch.aten.view %6731, %6746 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_7251 = torch.constant.int 4 + %int1_7252 = torch.constant.int 1 + %int8_7253 = torch.constant.int 8 + %int128_7254 = torch.constant.int 128 + %6748 = torch.prim.ListConstruct %int4_7251, %int1_7252, %int8_7253, %int128_7254 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6749 = torch.aten.view %6738, %6748 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_7255 = torch.constant.int 4 %int1_7256 = torch.constant.int 1 - %false_7257 = torch.constant.bool false - %5780 = torch.aten.gather %arg3, %int1_7256, %5779, %false_7257 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_7258 = torch.constant.int 32 - %5781 = torch.aten.remainder.Scalar %arg2, %int32_7258 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int8_7257 = torch.constant.int 8 + %int128_7258 = torch.constant.int 128 + %6750 = torch.prim.ListConstruct %int4_7255, %int1_7256, %int8_7257, %int128_7258 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6751 = torch.aten.view %6745, %6750 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> %int1_7259 = torch.constant.int 1 - %5782 = torch.aten.unsqueeze %5781, %int1_7259 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_7260 = torch.constant.none - %5783 = torch.aten.clone %292, %none_7260 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_7261 = torch.constant.int 0 - %5784 = torch.aten.unsqueeze %5783, %int0_7261 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_7262 = torch.constant.int 4 - %int1_7263 = torch.constant.int 1 - %5785 = torch.prim.ListConstruct %int4_7262, %int1_7263 : (!torch.int, !torch.int) -> !torch.list + %int2_7260 = torch.constant.int 2 + %6752 = torch.aten.transpose.int %6747, %int1_7259, %int2_7260 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %6753 = torch.aten.mul.Tensor %6752, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_7261 = torch.constant.int 3 + %int0_7262 = torch.constant.int 0 + %int64_7263 = torch.constant.int 64 %int1_7264 = torch.constant.int 1 - %int1_7265 = torch.constant.int 1 - %5786 = torch.prim.ListConstruct %int1_7264, %int1_7265 : (!torch.int, !torch.int) -> !torch.list - %int4_7266 = torch.constant.int 4 - %int0_7267 = torch.constant.int 0 - %cpu_7268 = torch.constant.device "cpu" - %false_7269 = torch.constant.bool false - %5787 = torch.aten.empty_strided %5785, %5786, %int4_7266, %int0_7267, %cpu_7268, %false_7269 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int26_7270 = torch.constant.int 26 - %5788 = torch.aten.fill.Scalar %5787, %int26_7270 : 
!torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_7271 = torch.constant.int 4 - %int1_7272 = torch.constant.int 1 - %5789 = torch.prim.ListConstruct %int4_7271, %int1_7272 : (!torch.int, !torch.int) -> !torch.list - %5790 = torch.aten.repeat %5784, %5789 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_7273 = torch.constant.int 32 - %5791 = torch.aten.mul.Scalar %5780, %int32_7273 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7274 = torch.constant.int 1 - %5792 = torch.aten.add.Tensor %5791, %5788, %int1_7274 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_7275 = torch.constant.int 2 - %5793 = torch.aten.mul.Scalar %5792, %int2_7275 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7276 = torch.constant.int 1 - %5794 = torch.aten.add.Tensor %5793, %5790, %int1_7276 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_7277 = torch.constant.int 32 - %5795 = torch.aten.mul.Scalar %5794, %int32_7277 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %6754 = torch.aten.slice.Tensor %6752, %int3_7261, %int0_7262, %int64_7263, %int1_7264 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_7265 = torch.constant.int 3 + %int64_7266 = torch.constant.int 64 + %int9223372036854775807_7267 = torch.constant.int 9223372036854775807 + %int1_7268 = torch.constant.int 1 + %6755 = torch.aten.slice.Tensor %6752, %int3_7265, %int64_7266, %int9223372036854775807_7267, %int1_7268 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %6756 = torch.aten.neg %6755 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %6757 = torch.prim.ListConstruct %6756, %6754 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_7269 = torch.constant.int -1 + %6758 = torch.aten.cat %6757, %int-1_7269 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %6759 = torch.aten.mul.Tensor %6758, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_7270 = torch.constant.int 1 + %6760 = torch.aten.add.Tensor %6753, %6759, %int1_7270 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_7271 = torch.constant.int 1 + %int2_7272 = torch.constant.int 2 + %6761 = torch.aten.transpose.int %6760, %int1_7271, %int2_7272 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_7273 = torch.constant.int 1 + %int2_7274 = torch.constant.int 2 + %6762 = torch.aten.transpose.int %6749, %int1_7273, %int2_7274 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %6763 = torch.aten.mul.Tensor %6762, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_7275 = torch.constant.int 3 + %int0_7276 = torch.constant.int 0 + %int64_7277 = torch.constant.int 64 %int1_7278 = torch.constant.int 1 - %5796 = torch.aten.add.Tensor %5795, %5782, %int1_7278 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %5797 = torch.prim.ListConstruct %5796 : (!torch.vtensor<[4,1],si64>) -> 
!torch.list> - %false_7279 = torch.constant.bool false - %5798 = torch.aten.index_put %5777, %5797, %5729, %false_7279 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5798, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_7280 = torch.constant.int 32 - %int2_7281 = torch.constant.int 2 - %int32_7282 = torch.constant.int 32 - %int8_7283 = torch.constant.int 8 - %int128_7284 = torch.constant.int 128 - %5799 = torch.prim.ListConstruct %437, %int32_7280, %int2_7281, %int32_7282, %int8_7283, %int128_7284 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5800 = torch.aten.view %5798, %5799 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5800, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_7285 = torch.constant.int 2097152 - %5801 = torch.prim.ListConstruct %437, %int2097152_7285 : (!torch.int, !torch.int) -> !torch.list - %5802 = torch.aten.view %5800, %5801 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5802, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_7286 = torch.constant.int 4 - %5803 = torch.prim.ListConstruct %int4_7286, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_7287 = torch.constant.int 1 - %5804 = torch.prim.ListConstruct %358, %int1_7287 : (!torch.int, !torch.int) -> !torch.list - %int4_7288 = torch.constant.int 4 - %int0_7289 = torch.constant.int 0 - %cpu_7290 = torch.constant.device "cpu" - %false_7291 = torch.constant.bool false - %5805 = torch.aten.empty_strided %5803, %5804, %int4_7288, %int0_7289, %cpu_7290, %false_7291 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5805, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int26_7292 = torch.constant.int 26 - %5806 = torch.aten.fill.Scalar %5805, %int26_7292 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5806, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_7293 = torch.constant.int 32 - %5807 = torch.aten.mul.Scalar %arg3, %int32_7293 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5807, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_7294 = torch.constant.int 1 - %5808 = torch.aten.add.Tensor %5807, %5806, %int1_7294 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %5808, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %6764 = torch.aten.slice.Tensor %6762, %int3_7275, %int0_7276, %int64_7277, %int1_7278 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_7279 = torch.constant.int 3 + %int64_7280 = torch.constant.int 64 + %int9223372036854775807_7281 = torch.constant.int 9223372036854775807 + %int1_7282 = torch.constant.int 1 + %6765 = torch.aten.slice.Tensor %6762, %int3_7279, %int64_7280, %int9223372036854775807_7281, %int1_7282 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %6766 = 
torch.aten.neg %6765 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %6767 = torch.prim.ListConstruct %6766, %6764 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_7283 = torch.constant.int -1 + %6768 = torch.aten.cat %6767, %int-1_7283 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %6769 = torch.aten.mul.Tensor %6768, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_7284 = torch.constant.int 1 + %6770 = torch.aten.add.Tensor %6763, %6769, %int1_7284 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_7285 = torch.constant.int 1 + %int2_7286 = torch.constant.int 2 + %6771 = torch.aten.transpose.int %6770, %int1_7285, %int2_7286 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_7287 = torch.constant.int 32 + %6772 = torch.aten.floor_divide.Scalar %arg2, %int32_7287 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_7288 = torch.constant.int 1 + %6773 = torch.aten.unsqueeze %6772, %int1_7288 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_7289 = torch.constant.int 1 + %false_7290 = torch.constant.bool false + %6774 = torch.aten.gather %arg3, %int1_7289, %6773, %false_7290 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_7291 = torch.constant.int 4 + %int1_7292 = torch.constant.int 1 + %int1_7293 = torch.constant.int 1 + %6775 = torch.prim.ListConstruct %int4_7291, %int1_7292, %int1_7293 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6776 = torch.aten.view %6774, %6775 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_7294 = torch.constant.int 32 + %6777 = torch.aten.remainder.Scalar %arg2, %int32_7294 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> %int4_7295 = torch.constant.int 4 - %5809 = torch.aten.mul.int %int4_7295, %358 : !torch.int, !torch.int -> !torch.int - %5810 = torch.prim.ListConstruct %5809 : (!torch.int) -> !torch.list - %5811 = torch.aten.view %5808, %5810 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %5811, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_7296 = torch.constant.int 32 - %int2_7297 = torch.constant.int 2 - %int32_7298 = torch.constant.int 32 - %int8_7299 = torch.constant.int 8 - %int128_7300 = torch.constant.int 128 - %5812 = torch.prim.ListConstruct %437, %int32_7296, %int2_7297, %int32_7298, %int8_7299, %int128_7300 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5813 = torch.aten.view %5802, %5812 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5813, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7301 = torch.constant.int 32 - %5814 = torch.aten.mul.int %437, %int32_7301 : !torch.int, !torch.int -> !torch.int - %int2_7302 = torch.constant.int 2 - %int32_7303 = torch.constant.int 32 - %int8_7304 = torch.constant.int 8 - %int128_7305 = torch.constant.int 128 - %5815 = torch.prim.ListConstruct %5814, %int2_7302, %int32_7303, %int8_7304, %int128_7305 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5816 = torch.aten.view %5813, %5815 : 
!torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %5816, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_7306 = torch.constant.int 0 - %5817 = torch.aten.index_select %5816, %int0_7306, %5811 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %5817, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_7307 = torch.constant.int 4 - %int2_7308 = torch.constant.int 2 - %int32_7309 = torch.constant.int 32 - %int8_7310 = torch.constant.int 8 - %int128_7311 = torch.constant.int 128 - %5818 = torch.prim.ListConstruct %int4_7307, %358, %int2_7308, %int32_7309, %int8_7310, %int128_7311 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5819 = torch.aten.view %5817, %5818 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5819, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_7312 = torch.constant.int 0 - %int0_7313 = torch.constant.int 0 - %int9223372036854775807_7314 = torch.constant.int 9223372036854775807 + %int1_7296 = torch.constant.int 1 + %int1_7297 = torch.constant.int 1 + %6778 = torch.prim.ListConstruct %int4_7295, %int1_7296, %int1_7297 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6779 = torch.aten.view %6777, %6778 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_7298 = torch.constant.int 8 + %none_7299 = torch.constant.none + %none_7300 = torch.constant.none + %cpu_7301 = torch.constant.device "cpu" + %false_7302 = torch.constant.bool false + %6780 = torch.aten.arange %int8_7298, %none_7299, %none_7300, %cpu_7301, %false_7302 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_7303 = torch.constant.int 1 + %int1_7304 = torch.constant.int 1 + %int8_7305 = torch.constant.int 8 + %6781 = torch.prim.ListConstruct %int1_7303, %int1_7304, %int8_7305 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6782 = torch.aten.view %6780, %6781 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_7306 = torch.constant.none + %6783 = torch.aten.clone %412, %none_7306 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6784 = torch.aten.detach %6783 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6785 = torch.aten.detach %6784 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6786 = torch.aten.detach %6785 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_7307 = torch.constant.int 1 + %int1_7308 = torch.constant.int 1 + %int1_7309 = torch.constant.int 1 + %6787 = torch.prim.ListConstruct %int1_7307, %int1_7308, %int1_7309 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6788 = torch.aten.view %6786, %6787 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_7310 = torch.constant.int 32 + %6789 = torch.aten.mul.Scalar %6776, %int32_7310 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int29 = torch.constant.int 29 + %int1_7311 = torch.constant.int 1 + %6790 = torch.aten.add.Scalar %6789, %int29, %int1_7311 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_7312 = torch.constant.int 2 + %6791 = torch.aten.mul.Scalar %6790, 
%int2_7312 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_7313 = torch.constant.int 1 + %6792 = torch.aten.add.Tensor %6791, %6788, %int1_7313 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_7314 = torch.constant.int 8 + %6793 = torch.aten.mul.Scalar %6792, %int8_7314 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_7315 = torch.constant.int 1 - %5820 = torch.aten.slice.Tensor %5819, %int0_7312, %int0_7313, %int9223372036854775807_7314, %int1_7315 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5820, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_7316 = torch.constant.int 1 - %int0_7317 = torch.constant.int 0 - %int9223372036854775807_7318 = torch.constant.int 9223372036854775807 - %int1_7319 = torch.constant.int 1 - %5821 = torch.aten.slice.Tensor %5820, %int1_7316, %int0_7317, %int9223372036854775807_7318, %int1_7319 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5821, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> + %6794 = torch.aten.add.Tensor %6793, %6782, %int1_7315 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_7316 = torch.constant.int 32 + %6795 = torch.aten.mul.Scalar %6794, %int32_7316 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_7317 = torch.constant.int 1 + %6796 = torch.aten.add.Tensor %6795, %6779, %int1_7317 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_7318 = torch.constant.int 5 + %6797 = torch.prims.convert_element_type %6771, %int5_7318 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_7319 = torch.constant.int 32 %int2_7320 = torch.constant.int 2 - %int0_7321 = torch.constant.int 0 - %5822 = torch.aten.select.int %5821, %int2_7320, %int0_7321 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5822, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int8_7321 = torch.constant.int 8 %int32_7322 = torch.constant.int 32 - %5823 = torch.aten.mul.int %358, %int32_7322 : !torch.int, !torch.int -> !torch.int - %int2_7323 = torch.constant.int 2 - %int0_7324 = torch.constant.int 0 - %int1_7325 = torch.constant.int 1 - %5824 = torch.aten.slice.Tensor %5822, %int2_7323, %int0_7324, %5823, %int1_7325 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5824, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_7326 = torch.constant.int 0 - %5825 = torch.aten.clone %5824, %int0_7326 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5825, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_7327 = torch.constant.int 1 - %5826 = torch.aten.size.int %5821, %int1_7327 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_7328 = torch.constant.int 32 - %5827 = 
torch.aten.mul.int %5826, %int32_7328 : !torch.int, !torch.int -> !torch.int - %int4_7329 = torch.constant.int 4 - %int8_7330 = torch.constant.int 8 - %int128_7331 = torch.constant.int 128 - %5828 = torch.prim.ListConstruct %int4_7329, %5827, %int8_7330, %int128_7331 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5829 = torch.aten._unsafe_view %5825, %5828 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5829, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_7332 = torch.constant.int 0 - %int0_7333 = torch.constant.int 0 - %int9223372036854775807_7334 = torch.constant.int 9223372036854775807 - %int1_7335 = torch.constant.int 1 - %5830 = torch.aten.slice.Tensor %5829, %int0_7332, %int0_7333, %int9223372036854775807_7334, %int1_7335 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5830, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_7336 = torch.constant.int 0 - %int0_7337 = torch.constant.int 0 - %int9223372036854775807_7338 = torch.constant.int 9223372036854775807 + %int128_7323 = torch.constant.int 128 + %6798 = torch.prim.ListConstruct %456, %int32_7319, %int2_7320, %int8_7321, %int32_7322, %int128_7323 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6799 = torch.aten.view %6619, %6798 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6799, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_7324 = torch.constant.int 128 + %6800 = torch.prim.ListConstruct %596, %int128_7324 : (!torch.int, !torch.int) -> !torch.list + %6801 = torch.aten.view %6799, %6800 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6801, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %6802 = torch.prim.ListConstruct %6796 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_7325 = torch.constant.bool false + %6803 = torch.aten.index_put %6801, %6802, %6797, %false_7325 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6803, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_7326 = torch.constant.int 32 + %int2_7327 = torch.constant.int 2 + %int8_7328 = torch.constant.int 8 + %int32_7329 = torch.constant.int 32 + %int128_7330 = torch.constant.int 128 + %6804 = torch.prim.ListConstruct %456, %int32_7326, %int2_7327, %int8_7328, %int32_7329, %int128_7330 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6805 = torch.aten.view %6803, %6804 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6805, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7331 = torch.constant.int 2097152 + %6806 = torch.prim.ListConstruct %456, %int2097152_7331 : (!torch.int, !torch.int) -> !torch.list + %6807 = torch.aten.view %6805, %6806 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6807, [%454], affine_map<()[s0] -> (s0, 2097152)> : 
!torch.vtensor<[?,2097152],f16> + %int32_7332 = torch.constant.int 32 + %int2_7333 = torch.constant.int 2 + %int8_7334 = torch.constant.int 8 + %int32_7335 = torch.constant.int 32 + %int128_7336 = torch.constant.int 128 + %6808 = torch.prim.ListConstruct %456, %int32_7332, %int2_7333, %int8_7334, %int32_7335, %int128_7336 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6809 = torch.aten.view %6807, %6808 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6809, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_7337 = torch.constant.int 128 + %6810 = torch.prim.ListConstruct %596, %int128_7337 : (!torch.int, !torch.int) -> !torch.list + %6811 = torch.aten.view %6809, %6810 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6811, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_7338 = torch.constant.none + %6812 = torch.aten.clone %413, %none_7338 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6813 = torch.aten.detach %6812 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6814 = torch.aten.detach %6813 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6815 = torch.aten.detach %6814 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> %int1_7339 = torch.constant.int 1 - %5831 = torch.aten.slice.Tensor %5819, %int0_7336, %int0_7337, %int9223372036854775807_7338, %int1_7339 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5831, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> %int1_7340 = torch.constant.int 1 - %int0_7341 = torch.constant.int 0 - %int9223372036854775807_7342 = torch.constant.int 9223372036854775807 - %int1_7343 = torch.constant.int 1 - %5832 = torch.aten.slice.Tensor %5831, %int1_7340, %int0_7341, %int9223372036854775807_7342, %int1_7343 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %5832, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_7344 = torch.constant.int 2 - %int1_7345 = torch.constant.int 1 - %5833 = torch.aten.select.int %5832, %int2_7344, %int1_7345 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5833, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_7346 = torch.constant.int 2 - %int0_7347 = torch.constant.int 0 + %int1_7341 = torch.constant.int 1 + %6816 = torch.prim.ListConstruct %int1_7339, %int1_7340, %int1_7341 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6817 = torch.aten.view %6815, %6816 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_7342 = torch.constant.int 32 + %6818 = torch.aten.mul.Scalar %6776, %int32_7342 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int29_7343 = torch.constant.int 29 + %int1_7344 = torch.constant.int 1 + %6819 = torch.aten.add.Scalar %6818, %int29_7343, %int1_7344 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_7345 = torch.constant.int 2 + %6820 = torch.aten.mul.Scalar %6819, %int2_7345 : 
!torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_7346 = torch.constant.int 1 + %6821 = torch.aten.add.Tensor %6820, %6817, %int1_7346 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_7347 = torch.constant.int 8 + %6822 = torch.aten.mul.Scalar %6821, %int8_7347 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_7348 = torch.constant.int 1 - %5834 = torch.aten.slice.Tensor %5833, %int2_7346, %int0_7347, %5823, %int1_7348 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5834, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_7349 = torch.constant.int 0 - %5835 = torch.aten.clone %5834, %int0_7349 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %5835, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %6823 = torch.aten.add.Tensor %6822, %6782, %int1_7348 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_7349 = torch.constant.int 32 + %6824 = torch.aten.mul.Scalar %6823, %int32_7349 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_7350 = torch.constant.int 1 - %5836 = torch.aten.size.int %5832, %int1_7350 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_7351 = torch.constant.int 32 - %5837 = torch.aten.mul.int %5836, %int32_7351 : !torch.int, !torch.int -> !torch.int - %int4_7352 = torch.constant.int 4 - %int8_7353 = torch.constant.int 8 - %int128_7354 = torch.constant.int 128 - %5838 = torch.prim.ListConstruct %int4_7352, %5837, %int8_7353, %int128_7354 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5839 = torch.aten._unsafe_view %5835, %5838 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5839, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_7355 = torch.constant.int 0 - %int0_7356 = torch.constant.int 0 - %int9223372036854775807_7357 = torch.constant.int 9223372036854775807 - %int1_7358 = torch.constant.int 1 - %5840 = torch.aten.slice.Tensor %5839, %int0_7355, %int0_7356, %int9223372036854775807_7357, %int1_7358 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %5840, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_7359 = torch.constant.int -2 - %5841 = torch.aten.unsqueeze %5830, %int-2_7359 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5841, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_7360 = torch.constant.int 1 - %5842 = torch.aten.size.int %5829, %int1_7360 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_7361 = torch.constant.int 4 - %int8_7362 = torch.constant.int 8 - %int4_7363 = torch.constant.int 4 - %int128_7364 = torch.constant.int 128 - %5843 = torch.prim.ListConstruct %int4_7361, %5842, %int8_7362, %int4_7363, %int128_7364 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7365 = torch.constant.bool false - %5844 = torch.aten.expand %5841, %5843, 
%false_7365 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5844, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_7366 = torch.constant.int 0 - %5845 = torch.aten.clone %5844, %int0_7366 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5845, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7367 = torch.constant.int 4 - %int32_7368 = torch.constant.int 32 - %int128_7369 = torch.constant.int 128 - %5846 = torch.prim.ListConstruct %int4_7367, %5842, %int32_7368, %int128_7369 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5847 = torch.aten._unsafe_view %5845, %5846 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5847, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_7370 = torch.constant.int -2 - %5848 = torch.aten.unsqueeze %5840, %int-2_7370 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %5848, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_7371 = torch.constant.int 1 - %5849 = torch.aten.size.int %5839, %int1_7371 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_7372 = torch.constant.int 4 - %int8_7373 = torch.constant.int 8 + %6825 = torch.aten.add.Tensor %6824, %6779, %int1_7350 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_7351 = torch.constant.int 5 + %6826 = torch.prims.convert_element_type %6751, %int5_7351 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %6827 = torch.prim.ListConstruct %6825 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_7352 = torch.constant.bool false + %6828 = torch.aten.index_put %6811, %6827, %6826, %false_7352 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %6828, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_7353 = torch.constant.int 32 + %int2_7354 = torch.constant.int 2 + %int8_7355 = torch.constant.int 8 + %int32_7356 = torch.constant.int 32 + %int128_7357 = torch.constant.int 128 + %6829 = torch.prim.ListConstruct %456, %int32_7353, %int2_7354, %int8_7355, %int32_7356, %int128_7357 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6830 = torch.aten.view %6828, %6829 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6830, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7358 = torch.constant.int 2097152 + %6831 = torch.prim.ListConstruct %456, %int2097152_7358 : (!torch.int, !torch.int) -> !torch.list + %6832 = torch.aten.view %6830, %6831 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %6832, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_7359 = torch.constant.none + %6833 = torch.aten.clone %414, %none_7359 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6834 = torch.aten.detach %6833 : !torch.vtensor<[],si64> -> 
!torch.vtensor<[],si64> + %6835 = torch.aten.detach %6834 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6836 = torch.aten.detach %6835 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_7360 = torch.constant.none + %6837 = torch.aten.clone %415, %none_7360 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6838 = torch.aten.detach %6837 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6839 = torch.aten.detach %6838 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6840 = torch.aten.detach %6839 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_7361 = torch.constant.none + %6841 = torch.aten.clone %416, %none_7361 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6842 = torch.aten.detach %6841 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6843 = torch.aten.detach %6842 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6844 = torch.aten.detach %6843 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_7362 = torch.constant.int 32 + %int2_7363 = torch.constant.int 2 + %int8_7364 = torch.constant.int 8 + %int32_7365 = torch.constant.int 32 + %int128_7366 = torch.constant.int 128 + %6845 = torch.prim.ListConstruct %456, %int32_7362, %int2_7363, %int8_7364, %int32_7365, %int128_7366 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + %6846 = torch.aten.view %6832, %6845 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %6846, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %6847 = torch_c.to_builtin_tensor %6846 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16> + %6848 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_7367 = tensor.cast %6848 : tensor<4x?xi64> to tensor<?x?xi64> + %6849 = torch_c.to_builtin_tensor %6836 : !torch.vtensor<[],si64> -> tensor<i64> + %6850 = torch_c.to_builtin_tensor %6840 : !torch.vtensor<[],si64> -> tensor<i64> + %6851 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%6847, %cast_7367, %6849, %6850) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16> + %cast_7368 = tensor.cast %6851 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16> + %6852 = torch_c.from_builtin_tensor %cast_7368 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %6852, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %6853 = torch_c.to_builtin_tensor %6846 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor<?x32x2x8x32x128xf16> + %6854 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_7369 = tensor.cast %6854 : tensor<4x?xi64> to tensor<?x?xi64> + %6855 = torch_c.to_builtin_tensor %6836 : !torch.vtensor<[],si64> -> tensor<i64> + %6856 = torch_c.to_builtin_tensor %6844 : !torch.vtensor<[],si64> -> tensor<i64> + %6857 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%6853, %cast_7369, %6855, %6856) : (tensor<?x32x2x8x32x128xf16>, tensor<?x?xi64>, tensor<i64>, tensor<i64>) -> tensor<?x?x8x32x128xf16> + %cast_7370 = tensor.cast %6857 : tensor<?x?x8x32x128xf16> to tensor<4x?x8x32x128xf16> + %6858 = torch_c.from_builtin_tensor %cast_7370 : tensor<4x?x8x32x128xf16> -> 
!torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %6858, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_7371 = torch.constant.int 2 + %int3_7372 = torch.constant.int 3 + %6859 = torch.aten.transpose.int %6852, %int2_7371, %int3_7372 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6859, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_7373 = torch.constant.int 0 + %6860 = torch.aten.clone %6859, %int0_7373 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6860, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int4_7374 = torch.constant.int 4 - %int128_7375 = torch.constant.int 128 - %5850 = torch.prim.ListConstruct %int4_7372, %5849, %int8_7373, %int4_7374, %int128_7375 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7376 = torch.constant.bool false - %5851 = torch.aten.expand %5848, %5850, %false_7376 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5851, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_7377 = torch.constant.int 0 - %5852 = torch.aten.clone %5851, %int0_7377 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %5852, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7378 = torch.constant.int 4 - %int32_7379 = torch.constant.int 32 - %int128_7380 = torch.constant.int 128 - %5853 = torch.prim.ListConstruct %int4_7378, %5849, %int32_7379, %int128_7380 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5854 = torch.aten._unsafe_view %5852, %5853 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %5854, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_7381 = torch.constant.int 1 - %int2_7382 = torch.constant.int 2 - %5855 = torch.aten.transpose.int %5735, %int1_7381, %int2_7382 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_7383 = torch.constant.int 1 - %int2_7384 = torch.constant.int 2 - %5856 = torch.aten.transpose.int %5847, %int1_7383, %int2_7384 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5856, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_7385 = torch.constant.int 1 - %int2_7386 = torch.constant.int 2 - %5857 = torch.aten.transpose.int %5854, %int1_7385, %int2_7386 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %5857, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_7387 = torch.constant.float 0.000000e+00 + %int8_7375 = torch.constant.int 8 + %int128_7376 = torch.constant.int 128 + %6861 = torch.prim.ListConstruct %int4_7374, %457, %int8_7375, %int128_7376 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6862 = torch.aten._unsafe_view %6860, %6861 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape 
%6862, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_7377 = torch.constant.int 2 + %int3_7378 = torch.constant.int 3 + %6863 = torch.aten.transpose.int %6858, %int2_7377, %int3_7378 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6863, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_7379 = torch.constant.int 0 + %6864 = torch.aten.clone %6863, %int0_7379 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %6864, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_7380 = torch.constant.int 4 + %int8_7381 = torch.constant.int 8 + %int128_7382 = torch.constant.int 128 + %6865 = torch.prim.ListConstruct %int4_7380, %457, %int8_7381, %int128_7382 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6866 = torch.aten._unsafe_view %6864, %6865 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %6866, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_7383 = torch.constant.int -2 + %6867 = torch.aten.unsqueeze %6862, %int-2_7383 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %6867, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_7384 = torch.constant.int 4 + %int8_7385 = torch.constant.int 8 + %int4_7386 = torch.constant.int 4 + %int128_7387 = torch.constant.int 128 + %6868 = torch.prim.ListConstruct %int4_7384, %457, %int8_7385, %int4_7386, %int128_7387 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %false_7388 = torch.constant.bool false - %none_7389 = torch.constant.none - %5858:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%5855, %5856, %5857, %float0.000000e00_7387, %false_7388, %368, %none_7389) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_7390 = torch.constant.int 1 - %int2_7391 = torch.constant.int 2 - %5859 = torch.aten.transpose.int %5858#0, %int1_7390, %int2_7391 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_7392 = torch.constant.int 4 - %int1_7393 = torch.constant.int 1 - %int4096_7394 = torch.constant.int 4096 - %5860 = torch.prim.ListConstruct %int4_7392, %int1_7393, %int4096_7394 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5861 = torch.aten.view %5859, %5860 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_7395 = torch.constant.int -2 - %int-1_7396 = torch.constant.int -1 - %5862 = torch.aten.transpose.int %293, %int-2_7395, %int-1_7396 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7397 = torch.constant.int 4 - %int4096_7398 = torch.constant.int 4096 - %5863 = torch.prim.ListConstruct %int4_7397, %int4096_7398 : (!torch.int, !torch.int) -> !torch.list - %5864 = torch.aten.view %5861, %5863 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5865 = torch.aten.mm %5864, %5862 : !torch.vtensor<[4,4096],f16>, 
!torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_7399 = torch.constant.int 4 - %int1_7400 = torch.constant.int 1 - %int4096_7401 = torch.constant.int 4096 - %5866 = torch.prim.ListConstruct %int4_7399, %int1_7400, %int4096_7401 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5867 = torch.aten.view %5865, %5866 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_7402 = torch.constant.int 1 - %5868 = torch.aten.add.Tensor %5695, %5867, %int1_7402 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_7403 = torch.constant.int 6 - %5869 = torch.prims.convert_element_type %5868, %int6_7403 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %6869 = torch.aten.expand %6867, %6868, %false_7388 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6869, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_7389 = torch.constant.int 0 + %6870 = torch.aten.clone %6869, %int0_7389 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6870, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_7390 = torch.constant.int 4 + %int32_7391 = torch.constant.int 32 + %int128_7392 = torch.constant.int 128 + %6871 = torch.prim.ListConstruct %int4_7390, %457, %int32_7391, %int128_7392 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6872 = torch.aten._unsafe_view %6870, %6871 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6872, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_7393 = torch.constant.int -2 + %6873 = torch.aten.unsqueeze %6866, %int-2_7393 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %6873, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_7394 = torch.constant.int 4 + %int8_7395 = torch.constant.int 8 + %int4_7396 = torch.constant.int 4 + %int128_7397 = torch.constant.int 128 + %6874 = torch.prim.ListConstruct %int4_7394, %457, %int8_7395, %int4_7396, %int128_7397 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_7398 = torch.constant.bool false + %6875 = torch.aten.expand %6873, %6874, %false_7398 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6875, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_7399 = torch.constant.int 0 + %6876 = torch.aten.clone %6875, %int0_7399 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %6876, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_7400 = torch.constant.int 4 + %int32_7401 = torch.constant.int 32 + %int128_7402 = torch.constant.int 128 + %6877 = torch.prim.ListConstruct %int4_7400, %457, %int32_7401, %int128_7402 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6878 = torch.aten._unsafe_view %6876, %6877 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %6878, [%453], 
affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_7403 = torch.constant.int 1 %int2_7404 = torch.constant.int 2 - %5870 = torch.aten.pow.Tensor_Scalar %5869, %int2_7404 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_7405 = torch.constant.int -1 - %5871 = torch.prim.ListConstruct %int-1_7405 : (!torch.int) -> !torch.list - %true_7406 = torch.constant.bool true - %none_7407 = torch.constant.none - %5872 = torch.aten.mean.dim %5870, %5871, %true_7406, %none_7407 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_7408 = torch.constant.float 9.9999997473787516E-6 - %int1_7409 = torch.constant.int 1 - %5873 = torch.aten.add.Scalar %5872, %float9.999990e-06_7408, %int1_7409 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %5874 = torch.aten.rsqrt %5873 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %5875 = torch.aten.mul.Tensor %5869, %5874 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_7410 = torch.constant.int 5 - %5876 = torch.prims.convert_element_type %5875, %int5_7410 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %5877 = torch.aten.mul.Tensor %294, %5876 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_7411 = torch.constant.int 5 - %5878 = torch.prims.convert_element_type %5877, %int5_7411 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_7412 = torch.constant.int -2 - %int-1_7413 = torch.constant.int -1 - %5879 = torch.aten.transpose.int %295, %int-2_7412, %int-1_7413 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %6879 = torch.aten.transpose.int %6761, %int1_7403, %int2_7404 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_7405 = torch.constant.int 1 + %int2_7406 = torch.constant.int 2 + %6880 = torch.aten.transpose.int %6872, %int1_7405, %int2_7406 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6880, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_7407 = torch.constant.int 1 + %int2_7408 = torch.constant.int 2 + %6881 = torch.aten.transpose.int %6878, %int1_7407, %int2_7408 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %6881, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_7409 = torch.constant.float 0.000000e+00 + %false_7410 = torch.constant.bool false + %none_7411 = torch.constant.none + %6882:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6879, %6880, %6881, %float0.000000e00_7409, %false_7410, %470, %none_7411) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_7412 = torch.constant.int 1 + %int2_7413 = torch.constant.int 2 + %6883 = torch.aten.transpose.int %6882#0, %int1_7412, %int2_7413 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> %int4_7414 = torch.constant.int 4 
- %int4096_7415 = torch.constant.int 4096 - %5880 = torch.prim.ListConstruct %int4_7414, %int4096_7415 : (!torch.int, !torch.int) -> !torch.list - %5881 = torch.aten.view %5878, %5880 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5882 = torch.aten.mm %5881, %5879 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_7416 = torch.constant.int 4 - %int1_7417 = torch.constant.int 1 - %int14336_7418 = torch.constant.int 14336 - %5883 = torch.prim.ListConstruct %int4_7416, %int1_7417, %int14336_7418 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5884 = torch.aten.view %5882, %5883 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %5885 = torch.aten.silu %5884 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_7419 = torch.constant.int -2 - %int-1_7420 = torch.constant.int -1 - %5886 = torch.aten.transpose.int %296, %int-2_7419, %int-1_7420 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_7421 = torch.constant.int 4 - %int4096_7422 = torch.constant.int 4096 - %5887 = torch.prim.ListConstruct %int4_7421, %int4096_7422 : (!torch.int, !torch.int) -> !torch.list - %5888 = torch.aten.view %5878, %5887 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5889 = torch.aten.mm %5888, %5886 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_7423 = torch.constant.int 4 - %int1_7424 = torch.constant.int 1 - %int14336_7425 = torch.constant.int 14336 - %5890 = torch.prim.ListConstruct %int4_7423, %int1_7424, %int14336_7425 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5891 = torch.aten.view %5889, %5890 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %5892 = torch.aten.mul.Tensor %5885, %5891 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_7426 = torch.constant.int -2 - %int-1_7427 = torch.constant.int -1 - %5893 = torch.aten.transpose.int %297, %int-2_7426, %int-1_7427 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_7428 = torch.constant.int 4 - %int14336_7429 = torch.constant.int 14336 - %5894 = torch.prim.ListConstruct %int4_7428, %int14336_7429 : (!torch.int, !torch.int) -> !torch.list - %5895 = torch.aten.view %5892, %5894 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %5896 = torch.aten.mm %5895, %5893 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_7430 = torch.constant.int 4 - %int1_7431 = torch.constant.int 1 - %int4096_7432 = torch.constant.int 4096 - %5897 = torch.prim.ListConstruct %int4_7430, %int1_7431, %int4096_7432 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5898 = torch.aten.view %5896, %5897 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_7433 = torch.constant.int 1 - %5899 = torch.aten.add.Tensor %5868, %5898, %int1_7433 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_7434 = torch.constant.int 6 - %5900 = torch.prims.convert_element_type %5899, %int6_7434 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_7435 = torch.constant.int 2 - %5901 = torch.aten.pow.Tensor_Scalar %5900, %int2_7435 : 
!torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int1_7415 = torch.constant.int 1 + %int4096_7416 = torch.constant.int 4096 + %6884 = torch.prim.ListConstruct %int4_7414, %int1_7415, %int4096_7416 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6885 = torch.aten.view %6883, %6884 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_7417 = torch.constant.int -2 + %int-1_7418 = torch.constant.int -1 + %6886 = torch.aten.transpose.int %417, %int-2_7417, %int-1_7418 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_7419 = torch.constant.int 5 + %6887 = torch.prims.convert_element_type %6886, %int5_7419 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_7420 = torch.constant.int 4 + %int4096_7421 = torch.constant.int 4096 + %6888 = torch.prim.ListConstruct %int4_7420, %int4096_7421 : (!torch.int, !torch.int) -> !torch.list + %6889 = torch.aten.view %6885, %6888 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6890 = torch.aten.mm %6889, %6887 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_7422 = torch.constant.int 4 + %int1_7423 = torch.constant.int 1 + %int4096_7424 = torch.constant.int 4096 + %6891 = torch.prim.ListConstruct %int4_7422, %int1_7423, %int4096_7424 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6892 = torch.aten.view %6890, %6891 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_7425 = torch.constant.int 1 + %6893 = torch.aten.add.Tensor %6714, %6892, %int1_7425 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_7426 = torch.constant.int 6 + %6894 = torch.prims.convert_element_type %6893, %int6_7426 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_7427 = torch.constant.int 2 + %6895 = torch.aten.pow.Tensor_Scalar %6894, %int2_7427 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_7428 = torch.constant.int -1 + %6896 = torch.prim.ListConstruct %int-1_7428 : (!torch.int) -> !torch.list + %true_7429 = torch.constant.bool true + %none_7430 = torch.constant.none + %6897 = torch.aten.mean.dim %6895, %6896, %true_7429, %none_7430 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_7431 = torch.constant.float 9.9999997473787516E-6 + %int1_7432 = torch.constant.int 1 + %6898 = torch.aten.add.Scalar %6897, %float9.999990e-06_7431, %int1_7432 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %6899 = torch.aten.rsqrt %6898 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %6900 = torch.aten.mul.Tensor %6894, %6899 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_7433 = torch.constant.int 5 + %6901 = torch.prims.convert_element_type %6900, %int5_7433 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %6902 = torch.aten.mul.Tensor %418, %6901 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_7434 = torch.constant.int 5 + %6903 = torch.prims.convert_element_type %6902, %int5_7434 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_7435 = torch.constant.int -2 
%int-1_7436 = torch.constant.int -1 - %5902 = torch.prim.ListConstruct %int-1_7436 : (!torch.int) -> !torch.list - %true_7437 = torch.constant.bool true - %none_7438 = torch.constant.none - %5903 = torch.aten.mean.dim %5901, %5902, %true_7437, %none_7438 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_7439 = torch.constant.float 9.9999997473787516E-6 - %int1_7440 = torch.constant.int 1 - %5904 = torch.aten.add.Scalar %5903, %float9.999990e-06_7439, %int1_7440 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %5905 = torch.aten.rsqrt %5904 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %5906 = torch.aten.mul.Tensor %5900, %5905 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_7441 = torch.constant.int 5 - %5907 = torch.prims.convert_element_type %5906, %int5_7441 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %5908 = torch.aten.mul.Tensor %298, %5907 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_7442 = torch.constant.int 5 - %5909 = torch.prims.convert_element_type %5908, %int5_7442 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %6904 = torch.aten.transpose.int %419, %int-2_7435, %int-1_7436 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_7437 = torch.constant.int 5 + %6905 = torch.prims.convert_element_type %6904, %int5_7437 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_7438 = torch.constant.int 4 + %int4096_7439 = torch.constant.int 4096 + %6906 = torch.prim.ListConstruct %int4_7438, %int4096_7439 : (!torch.int, !torch.int) -> !torch.list + %6907 = torch.aten.view %6903, %6906 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6908 = torch.aten.mm %6907, %6905 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_7440 = torch.constant.int 4 + %int1_7441 = torch.constant.int 1 + %int14336_7442 = torch.constant.int 14336 + %6909 = torch.prim.ListConstruct %int4_7440, %int1_7441, %int14336_7442 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6910 = torch.aten.view %6908, %6909 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %6911 = torch.aten.silu %6910 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> %int-2_7443 = torch.constant.int -2 %int-1_7444 = torch.constant.int -1 - %5910 = torch.aten.transpose.int %299, %int-2_7443, %int-1_7444 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7445 = torch.constant.int 4 - %int4096_7446 = torch.constant.int 4096 - %5911 = torch.prim.ListConstruct %int4_7445, %int4096_7446 : (!torch.int, !torch.int) -> !torch.list - %5912 = torch.aten.view %5909, %5911 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5913 = torch.aten.mm %5912, %5910 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_7447 = torch.constant.int 4 - %int1_7448 = torch.constant.int 1 - %int4096_7449 = torch.constant.int 4096 - %5914 = torch.prim.ListConstruct %int4_7447, %int1_7448, %int4096_7449 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5915 = torch.aten.view %5913, %5914 : 
!torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_7450 = torch.constant.int -2 - %int-1_7451 = torch.constant.int -1 - %5916 = torch.aten.transpose.int %300, %int-2_7450, %int-1_7451 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_7452 = torch.constant.int 4 - %int4096_7453 = torch.constant.int 4096 - %5917 = torch.prim.ListConstruct %int4_7452, %int4096_7453 : (!torch.int, !torch.int) -> !torch.list - %5918 = torch.aten.view %5909, %5917 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5919 = torch.aten.mm %5918, %5916 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %6912 = torch.aten.transpose.int %420, %int-2_7443, %int-1_7444 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_7445 = torch.constant.int 5 + %6913 = torch.prims.convert_element_type %6912, %int5_7445 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_7446 = torch.constant.int 4 + %int4096_7447 = torch.constant.int 4096 + %6914 = torch.prim.ListConstruct %int4_7446, %int4096_7447 : (!torch.int, !torch.int) -> !torch.list + %6915 = torch.aten.view %6903, %6914 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6916 = torch.aten.mm %6915, %6913 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_7448 = torch.constant.int 4 + %int1_7449 = torch.constant.int 1 + %int14336_7450 = torch.constant.int 14336 + %6917 = torch.prim.ListConstruct %int4_7448, %int1_7449, %int14336_7450 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6918 = torch.aten.view %6916, %6917 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %6919 = torch.aten.mul.Tensor %6911, %6918 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_7451 = torch.constant.int -2 + %int-1_7452 = torch.constant.int -1 + %6920 = torch.aten.transpose.int %421, %int-2_7451, %int-1_7452 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_7453 = torch.constant.int 5 + %6921 = torch.prims.convert_element_type %6920, %int5_7453 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> %int4_7454 = torch.constant.int 4 - %int1_7455 = torch.constant.int 1 - %int1024_7456 = torch.constant.int 1024 - %5920 = torch.prim.ListConstruct %int4_7454, %int1_7455, %int1024_7456 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5921 = torch.aten.view %5919, %5920 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_7457 = torch.constant.int -2 - %int-1_7458 = torch.constant.int -1 - %5922 = torch.aten.transpose.int %301, %int-2_7457, %int-1_7458 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_7459 = torch.constant.int 4 - %int4096_7460 = torch.constant.int 4096 - %5923 = torch.prim.ListConstruct %int4_7459, %int4096_7460 : (!torch.int, !torch.int) -> !torch.list - %5924 = torch.aten.view %5909, %5923 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %5925 = torch.aten.mm %5924, %5922 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_7461 = torch.constant.int 4 - %int1_7462 = torch.constant.int 1 - %int1024_7463 = 
torch.constant.int 1024 - %5926 = torch.prim.ListConstruct %int4_7461, %int1_7462, %int1024_7463 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5927 = torch.aten.view %5925, %5926 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_7464 = torch.constant.int 4 - %int1_7465 = torch.constant.int 1 - %int32_7466 = torch.constant.int 32 - %int128_7467 = torch.constant.int 128 - %5928 = torch.prim.ListConstruct %int4_7464, %int1_7465, %int32_7466, %int128_7467 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5929 = torch.aten.view %5915, %5928 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_7468 = torch.constant.int 4 - %int1_7469 = torch.constant.int 1 - %int8_7470 = torch.constant.int 8 - %int128_7471 = torch.constant.int 128 - %5930 = torch.prim.ListConstruct %int4_7468, %int1_7469, %int8_7470, %int128_7471 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5931 = torch.aten.view %5921, %5930 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int14336_7455 = torch.constant.int 14336 + %6922 = torch.prim.ListConstruct %int4_7454, %int14336_7455 : (!torch.int, !torch.int) -> !torch.list + %6923 = torch.aten.view %6919, %6922 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %6924 = torch.aten.mm %6923, %6921 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_7456 = torch.constant.int 4 + %int1_7457 = torch.constant.int 1 + %int4096_7458 = torch.constant.int 4096 + %6925 = torch.prim.ListConstruct %int4_7456, %int1_7457, %int4096_7458 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6926 = torch.aten.view %6924, %6925 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_7459 = torch.constant.int 1 + %6927 = torch.aten.add.Tensor %6893, %6926, %int1_7459 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_7460 = torch.constant.int 6 + %6928 = torch.prims.convert_element_type %6927, %int6_7460 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_7461 = torch.constant.int 2 + %6929 = torch.aten.pow.Tensor_Scalar %6928, %int2_7461 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_7462 = torch.constant.int -1 + %6930 = torch.prim.ListConstruct %int-1_7462 : (!torch.int) -> !torch.list + %true_7463 = torch.constant.bool true + %none_7464 = torch.constant.none + %6931 = torch.aten.mean.dim %6929, %6930, %true_7463, %none_7464 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_7465 = torch.constant.float 9.9999997473787516E-6 + %int1_7466 = torch.constant.int 1 + %6932 = torch.aten.add.Scalar %6931, %float9.999990e-06_7465, %int1_7466 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %6933 = torch.aten.rsqrt %6932 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %6934 = torch.aten.mul.Tensor %6928, %6933 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_7467 = torch.constant.int 5 + %6935 = torch.prims.convert_element_type %6934, %int5_7467 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %6936 = torch.aten.mul.Tensor %422, %6935 : !torch.vtensor<[4096],f32>, 
!torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_7468 = torch.constant.int 5 + %6937 = torch.prims.convert_element_type %6936, %int5_7468 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_7469 = torch.constant.int -2 + %int-1_7470 = torch.constant.int -1 + %6938 = torch.aten.transpose.int %423, %int-2_7469, %int-1_7470 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_7471 = torch.constant.int 5 + %6939 = torch.prims.convert_element_type %6938, %int5_7471 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> %int4_7472 = torch.constant.int 4 - %int1_7473 = torch.constant.int 1 - %int8_7474 = torch.constant.int 8 - %int128_7475 = torch.constant.int 128 - %5932 = torch.prim.ListConstruct %int4_7472, %int1_7473, %int8_7474, %int128_7475 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5933 = torch.aten.view %5927, %5932 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_7476 = torch.constant.int 6 - %5934 = torch.prims.convert_element_type %5929, %int6_7476 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %5935 = torch_c.to_builtin_tensor %5934 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %5936 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %5937 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%5935, %5936) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %5938 = torch_c.from_builtin_tensor %5937 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_7477 = torch.constant.int 5 - %5939 = torch.prims.convert_element_type %5938, %int5_7477 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_7478 = torch.constant.int 6 - %5940 = torch.prims.convert_element_type %5931, %int6_7478 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %5941 = torch_c.to_builtin_tensor %5940 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %5942 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %5943 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%5941, %5942) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %5944 = torch_c.from_builtin_tensor %5943 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> + %int4096_7473 = torch.constant.int 4096 + %6940 = torch.prim.ListConstruct %int4_7472, %int4096_7473 : (!torch.int, !torch.int) -> !torch.list + %6941 = torch.aten.view %6937, %6940 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6942 = torch.aten.mm %6941, %6939 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_7474 = torch.constant.int 4 + %int1_7475 = torch.constant.int 1 + %int4096_7476 = torch.constant.int 4096 + %6943 = torch.prim.ListConstruct %int4_7474, %int1_7475, %int4096_7476 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6944 = torch.aten.view %6942, %6943 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_7477 = torch.constant.int -2 + %int-1_7478 = torch.constant.int -1 + %6945 = torch.aten.transpose.int %424, %int-2_7477, %int-1_7478 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> %int5_7479 = torch.constant.int 5 - %5945 = 
torch.prims.convert_element_type %5944, %int5_7479 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_7480 = torch.constant.int 32 - %5946 = torch.aten.floor_divide.Scalar %arg2, %int32_7480 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_7481 = torch.constant.int 1 - %5947 = torch.aten.unsqueeze %5946, %int1_7481 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7482 = torch.constant.int 1 - %false_7483 = torch.constant.bool false - %5948 = torch.aten.gather %arg3, %int1_7482, %5947, %false_7483 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_7484 = torch.constant.int 32 - %5949 = torch.aten.remainder.Scalar %arg2, %int32_7484 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_7485 = torch.constant.int 1 - %5950 = torch.aten.unsqueeze %5949, %int1_7485 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_7486 = torch.constant.none - %5951 = torch.aten.clone %302, %none_7486 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_7487 = torch.constant.int 0 - %5952 = torch.aten.unsqueeze %5951, %int0_7487 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + %6946 = torch.prims.convert_element_type %6945, %int5_7479 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_7480 = torch.constant.int 4 + %int4096_7481 = torch.constant.int 4096 + %6947 = torch.prim.ListConstruct %int4_7480, %int4096_7481 : (!torch.int, !torch.int) -> !torch.list + %6948 = torch.aten.view %6937, %6947 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6949 = torch.aten.mm %6948, %6946 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_7482 = torch.constant.int 4 + %int1_7483 = torch.constant.int 1 + %int1024_7484 = torch.constant.int 1024 + %6950 = torch.prim.ListConstruct %int4_7482, %int1_7483, %int1024_7484 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6951 = torch.aten.view %6949, %6950 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int-2_7485 = torch.constant.int -2 + %int-1_7486 = torch.constant.int -1 + %6952 = torch.aten.transpose.int %425, %int-2_7485, %int-1_7486 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_7487 = torch.constant.int 5 + %6953 = torch.prims.convert_element_type %6952, %int5_7487 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> %int4_7488 = torch.constant.int 4 - %int1_7489 = torch.constant.int 1 - %5953 = torch.prim.ListConstruct %int4_7488, %int1_7489 : (!torch.int, !torch.int) -> !torch.list - %int1_7490 = torch.constant.int 1 + %int4096_7489 = torch.constant.int 4096 + %6954 = torch.prim.ListConstruct %int4_7488, %int4096_7489 : (!torch.int, !torch.int) -> !torch.list + %6955 = torch.aten.view %6937, %6954 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %6956 = torch.aten.mm %6955, %6953 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_7490 = torch.constant.int 4 %int1_7491 = torch.constant.int 1 - %5954 = torch.prim.ListConstruct %int1_7490, %int1_7491 : (!torch.int, !torch.int) -> !torch.list - %int4_7492 = torch.constant.int 4 - %int0_7493 = torch.constant.int 0 - %cpu_7494 = torch.constant.device "cpu" - 
%false_7495 = torch.constant.bool false - %5955 = torch.aten.empty_strided %5953, %5954, %int4_7492, %int0_7493, %cpu_7494, %false_7495 : !torch.list<int>, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int27 = torch.constant.int 27 - %5956 = torch.aten.fill.Scalar %5955, %int27 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_7496 = torch.constant.int 4 - %int1_7497 = torch.constant.int 1 - %5957 = torch.prim.ListConstruct %int4_7496, %int1_7497 : (!torch.int, !torch.int) -> !torch.list<int> - %5958 = torch.aten.repeat %5952, %5957 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64> - %int32_7498 = torch.constant.int 32 - %5959 = torch.aten.mul.Scalar %5948, %int32_7498 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7499 = torch.constant.int 1 - %5960 = torch.aten.add.Tensor %5959, %5956, %int1_7499 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_7500 = torch.constant.int 2 - %5961 = torch.aten.mul.Scalar %5960, %int2_7500 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7501 = torch.constant.int 1 - %5962 = torch.aten.add.Tensor %5961, %5958, %int1_7501 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_7502 = torch.constant.int 32 - %5963 = torch.aten.mul.Scalar %5962, %int32_7502 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7503 = torch.constant.int 1 - %5964 = torch.aten.add.Tensor %5963, %5950, %int1_7503 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_7504 = torch.constant.int 32 - %int2_7505 = torch.constant.int 2 - %int32_7506 = torch.constant.int 32 - %int8_7507 = torch.constant.int 8 - %int128_7508 = torch.constant.int 128 - %5965 = torch.prim.ListConstruct %437, %int32_7504, %int2_7505, %int32_7506, %int8_7507, %int128_7508 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %5966 = torch.aten.view %5802, %5965 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5966, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7509 = torch.constant.int 32 - %5967 = torch.aten.mul.int %437, %int32_7509 : !torch.int, !torch.int -> !torch.int - %int2_7510 = torch.constant.int 2 - %5968 = torch.aten.mul.int %5967, %int2_7510 : !torch.int, !torch.int -> !torch.int - %int32_7511 = torch.constant.int 32 - %5969 = torch.aten.mul.int %5968, %int32_7511 : !torch.int, !torch.int -> !torch.int - %int8_7512 = torch.constant.int 8 - %int128_7513 = torch.constant.int 128 - %5970 = torch.prim.ListConstruct %5969, %int8_7512, %int128_7513 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> - %5971 = torch.aten.view %5966, %5970 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5971, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %5972 = torch.prim.ListConstruct %5964 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>> - %false_7514 = torch.constant.bool false - %5973 = torch.aten.index_put %5971, %5972, %5945, %false_7514 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5973, [%357],
affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_7515 = torch.constant.int 32 - %int2_7516 = torch.constant.int 2 - %int32_7517 = torch.constant.int 32 - %int8_7518 = torch.constant.int 8 - %int128_7519 = torch.constant.int 128 - %5974 = torch.prim.ListConstruct %437, %int32_7515, %int2_7516, %int32_7517, %int8_7518, %int128_7519 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5975 = torch.aten.view %5973, %5974 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5975, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_7520 = torch.constant.int 2097152 - %5976 = torch.prim.ListConstruct %437, %int2097152_7520 : (!torch.int, !torch.int) -> !torch.list - %5977 = torch.aten.view %5975, %5976 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %5977, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_7521 = torch.constant.int 32 - %int2_7522 = torch.constant.int 2 - %int32_7523 = torch.constant.int 32 - %int8_7524 = torch.constant.int 8 - %int128_7525 = torch.constant.int 128 - %5978 = torch.prim.ListConstruct %437, %int32_7521, %int2_7522, %int32_7523, %int8_7524, %int128_7525 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %5979 = torch.aten.view %5977, %5978 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %5979, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_7526 = torch.constant.int 8 - %int128_7527 = torch.constant.int 128 - %5980 = torch.prim.ListConstruct %5969, %int8_7526, %int128_7527 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %5981 = torch.aten.view %5979, %5980 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %5981, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_7528 = torch.constant.int 32 - %5982 = torch.aten.floor_divide.Scalar %arg2, %int32_7528 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_7529 = torch.constant.int 1 - %5983 = torch.aten.unsqueeze %5982, %int1_7529 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1024_7492 = torch.constant.int 1024 + %6957 = torch.prim.ListConstruct %int4_7490, %int1_7491, %int1024_7492 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6958 = torch.aten.view %6956, %6957 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_7493 = torch.constant.int 4 + %int1_7494 = torch.constant.int 1 + %int32_7495 = torch.constant.int 32 + %int128_7496 = torch.constant.int 128 + %6959 = torch.prim.ListConstruct %int4_7493, %int1_7494, %int32_7495, %int128_7496 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6960 = torch.aten.view %6944, %6959 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_7497 = torch.constant.int 4 + %int1_7498 = torch.constant.int 1 + %int8_7499 = torch.constant.int 8 + %int128_7500 = torch.constant.int 128 + %6961 = torch.prim.ListConstruct %int4_7497, %int1_7498, %int8_7499, %int128_7500 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6962 = torch.aten.view %6951, %6961 : 
!torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_7501 = torch.constant.int 4 + %int1_7502 = torch.constant.int 1 + %int8_7503 = torch.constant.int 8 + %int128_7504 = torch.constant.int 128 + %6963 = torch.prim.ListConstruct %int4_7501, %int1_7502, %int8_7503, %int128_7504 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %6964 = torch.aten.view %6958, %6963 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_7505 = torch.constant.int 1 + %int2_7506 = torch.constant.int 2 + %6965 = torch.aten.transpose.int %6960, %int1_7505, %int2_7506 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %6966 = torch.aten.mul.Tensor %6965, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_7507 = torch.constant.int 3 + %int0_7508 = torch.constant.int 0 + %int64_7509 = torch.constant.int 64 + %int1_7510 = torch.constant.int 1 + %6967 = torch.aten.slice.Tensor %6965, %int3_7507, %int0_7508, %int64_7509, %int1_7510 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_7511 = torch.constant.int 3 + %int64_7512 = torch.constant.int 64 + %int9223372036854775807_7513 = torch.constant.int 9223372036854775807 + %int1_7514 = torch.constant.int 1 + %6968 = torch.aten.slice.Tensor %6965, %int3_7511, %int64_7512, %int9223372036854775807_7513, %int1_7514 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %6969 = torch.aten.neg %6968 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %6970 = torch.prim.ListConstruct %6969, %6967 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_7515 = torch.constant.int -1 + %6971 = torch.aten.cat %6970, %int-1_7515 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %6972 = torch.aten.mul.Tensor %6971, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_7516 = torch.constant.int 1 + %6973 = torch.aten.add.Tensor %6966, %6972, %int1_7516 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_7517 = torch.constant.int 1 + %int2_7518 = torch.constant.int 2 + %6974 = torch.aten.transpose.int %6973, %int1_7517, %int2_7518 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int1_7519 = torch.constant.int 1 + %int2_7520 = torch.constant.int 2 + %6975 = torch.aten.transpose.int %6962, %int1_7519, %int2_7520 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %6976 = torch.aten.mul.Tensor %6975, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_7521 = torch.constant.int 3 + %int0_7522 = torch.constant.int 0 + %int64_7523 = torch.constant.int 64 + %int1_7524 = torch.constant.int 1 + %6977 = torch.aten.slice.Tensor %6975, %int3_7521, %int0_7522, %int64_7523, %int1_7524 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_7525 = torch.constant.int 3 + %int64_7526 = torch.constant.int 64 + %int9223372036854775807_7527 = torch.constant.int 9223372036854775807 + %int1_7528 = torch.constant.int 1 + %6978 = torch.aten.slice.Tensor %6975, 
%int3_7525, %int64_7526, %int9223372036854775807_7527, %int1_7528 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %6979 = torch.aten.neg %6978 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %6980 = torch.prim.ListConstruct %6979, %6977 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_7529 = torch.constant.int -1 + %6981 = torch.aten.cat %6980, %int-1_7529 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %6982 = torch.aten.mul.Tensor %6981, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> %int1_7530 = torch.constant.int 1 - %false_7531 = torch.constant.bool false - %5984 = torch.aten.gather %arg3, %int1_7530, %5983, %false_7531 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_7532 = torch.constant.int 32 - %5985 = torch.aten.remainder.Scalar %arg2, %int32_7532 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_7533 = torch.constant.int 1 - %5986 = torch.aten.unsqueeze %5985, %int1_7533 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_7534 = torch.constant.none - %5987 = torch.aten.clone %303, %none_7534 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_7535 = torch.constant.int 0 - %5988 = torch.aten.unsqueeze %5987, %int0_7535 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_7536 = torch.constant.int 4 - %int1_7537 = torch.constant.int 1 - %5989 = torch.prim.ListConstruct %int4_7536, %int1_7537 : (!torch.int, !torch.int) -> !torch.list + %6983 = torch.aten.add.Tensor %6976, %6982, %int1_7530 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %int1_7531 = torch.constant.int 1 + %int2_7532 = torch.constant.int 2 + %6984 = torch.aten.transpose.int %6983, %int1_7531, %int2_7532 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_7533 = torch.constant.int 32 + %6985 = torch.aten.floor_divide.Scalar %arg2, %int32_7533 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_7534 = torch.constant.int 1 + %6986 = torch.aten.unsqueeze %6985, %int1_7534 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_7535 = torch.constant.int 1 + %false_7536 = torch.constant.bool false + %6987 = torch.aten.gather %arg3, %int1_7535, %6986, %false_7536 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_7537 = torch.constant.int 4 %int1_7538 = torch.constant.int 1 %int1_7539 = torch.constant.int 1 - %5990 = torch.prim.ListConstruct %int1_7538, %int1_7539 : (!torch.int, !torch.int) -> !torch.list - %int4_7540 = torch.constant.int 4 - %int0_7541 = torch.constant.int 0 - %cpu_7542 = torch.constant.device "cpu" - %false_7543 = torch.constant.bool false - %5991 = torch.aten.empty_strided %5989, %5990, %int4_7540, %int0_7541, %cpu_7542, %false_7543 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int27_7544 = torch.constant.int 27 - %5992 = torch.aten.fill.Scalar %5991, %int27_7544 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_7545 = torch.constant.int 4 - %int1_7546 = torch.constant.int 1 - %5993 = torch.prim.ListConstruct 
%int4_7545, %int1_7546 : (!torch.int, !torch.int) -> !torch.list<int> - %5994 = torch.aten.repeat %5988, %5993 : !torch.vtensor<[1],si64>, !torch.list<int> -> !torch.vtensor<[4,1],si64> - %int32_7547 = torch.constant.int 32 - %5995 = torch.aten.mul.Scalar %5984, %int32_7547 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7548 = torch.constant.int 1 - %5996 = torch.aten.add.Tensor %5995, %5992, %int1_7548 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_7549 = torch.constant.int 2 - %5997 = torch.aten.mul.Scalar %5996, %int2_7549 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %6988 = torch.prim.ListConstruct %int4_7537, %int1_7538, %int1_7539 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> + %6989 = torch.aten.view %6987, %6988 : !torch.vtensor<[4,1],si64>, !torch.list<int> -> !torch.vtensor<[4,1,1],si64> + %int32_7540 = torch.constant.int 32 + %6990 = torch.aten.remainder.Scalar %arg2, %int32_7540 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_7541 = torch.constant.int 4 + %int1_7542 = torch.constant.int 1 + %int1_7543 = torch.constant.int 1 + %6991 = torch.prim.ListConstruct %int4_7541, %int1_7542, %int1_7543 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> + %6992 = torch.aten.view %6990, %6991 : !torch.vtensor<[4],si64>, !torch.list<int> -> !torch.vtensor<[4,1,1],si64> + %int8_7544 = torch.constant.int 8 + %none_7545 = torch.constant.none + %none_7546 = torch.constant.none + %cpu_7547 = torch.constant.device "cpu" + %false_7548 = torch.constant.bool false + %6993 = torch.aten.arange %int8_7544, %none_7545, %none_7546, %cpu_7547, %false_7548 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_7549 = torch.constant.int 1 %int1_7550 = torch.constant.int 1 - %5998 = torch.aten.add.Tensor %5997, %5994, %int1_7550 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_7551 = torch.constant.int 32 - %5999 = torch.aten.mul.Scalar %5998, %int32_7551 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7552 = torch.constant.int 1 - %6000 = torch.aten.add.Tensor %5999, %5986, %int1_7552 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %6001 = torch.prim.ListConstruct %6000 : (!torch.vtensor<[4,1],si64>) -> !torch.list<optional<vtensor>> - %false_7553 = torch.constant.bool false - %6002 = torch.aten.index_put %5981, %6001, %5933, %false_7553 : !torch.vtensor<[?,8,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6002, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_7554 = torch.constant.int 32 - %int2_7555 = torch.constant.int 2 + %int8_7551 = torch.constant.int 8 + %6994 = torch.prim.ListConstruct %int1_7549, %int1_7550, %int8_7551 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> + %6995 = torch.aten.view %6993, %6994 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[1,1,8],si64> + %none_7552 = torch.constant.none + %6996 = torch.aten.clone %426, %none_7552 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %6997 = torch.aten.detach %6996 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6998 = torch.aten.detach %6997 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %6999 = torch.aten.detach %6998 : !torch.vtensor<[],si64> -> 
!torch.vtensor<[],si64> + %int1_7553 = torch.constant.int 1 + %int1_7554 = torch.constant.int 1 + %int1_7555 = torch.constant.int 1 + %7000 = torch.prim.ListConstruct %int1_7553, %int1_7554, %int1_7555 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7001 = torch.aten.view %6999, %7000 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> %int32_7556 = torch.constant.int 32 - %int8_7557 = torch.constant.int 8 - %int128_7558 = torch.constant.int 128 - %6003 = torch.prim.ListConstruct %437, %int32_7554, %int2_7555, %int32_7556, %int8_7557, %int128_7558 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6004 = torch.aten.view %6002, %6003 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6004, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_7559 = torch.constant.int 2097152 - %6005 = torch.prim.ListConstruct %437, %int2097152_7559 : (!torch.int, !torch.int) -> !torch.list - %6006 = torch.aten.view %6004, %6005 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6006, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_7560 = torch.constant.int 4 - %6007 = torch.prim.ListConstruct %int4_7560, %358 : (!torch.int, !torch.int) -> !torch.list + %7002 = torch.aten.mul.Scalar %6989, %int32_7556 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int30 = torch.constant.int 30 + %int1_7557 = torch.constant.int 1 + %7003 = torch.aten.add.Scalar %7002, %int30, %int1_7557 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_7558 = torch.constant.int 2 + %7004 = torch.aten.mul.Scalar %7003, %int2_7558 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_7559 = torch.constant.int 1 + %7005 = torch.aten.add.Tensor %7004, %7001, %int1_7559 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_7560 = torch.constant.int 8 + %7006 = torch.aten.mul.Scalar %7005, %int8_7560 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_7561 = torch.constant.int 1 - %6008 = torch.prim.ListConstruct %358, %int1_7561 : (!torch.int, !torch.int) -> !torch.list - %int4_7562 = torch.constant.int 4 - %int0_7563 = torch.constant.int 0 - %cpu_7564 = torch.constant.device "cpu" - %false_7565 = torch.constant.bool false - %6009 = torch.aten.empty_strided %6007, %6008, %int4_7562, %int0_7563, %cpu_7564, %false_7565 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6009, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int27_7566 = torch.constant.int 27 - %6010 = torch.aten.fill.Scalar %6009, %int27_7566 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6010, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_7567 = torch.constant.int 32 - %6011 = torch.aten.mul.Scalar %arg3, %int32_7567 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6011, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_7568 = torch.constant.int 1 - %6012 = torch.aten.add.Tensor %6011, %6010, %int1_7568 : !torch.vtensor<[4,?],si64>, 
!torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6012, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_7569 = torch.constant.int 4 - %6013 = torch.aten.mul.int %int4_7569, %358 : !torch.int, !torch.int -> !torch.int - %6014 = torch.prim.ListConstruct %6013 : (!torch.int) -> !torch.list - %6015 = torch.aten.view %6012, %6014 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %6015, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_7570 = torch.constant.int 32 - %int2_7571 = torch.constant.int 2 + %7007 = torch.aten.add.Tensor %7006, %6995, %int1_7561 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_7562 = torch.constant.int 32 + %7008 = torch.aten.mul.Scalar %7007, %int32_7562 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_7563 = torch.constant.int 1 + %7009 = torch.aten.add.Tensor %7008, %6992, %int1_7563 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_7564 = torch.constant.int 5 + %7010 = torch.prims.convert_element_type %6984, %int5_7564 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_7565 = torch.constant.int 32 + %int2_7566 = torch.constant.int 2 + %int8_7567 = torch.constant.int 8 + %int32_7568 = torch.constant.int 32 + %int128_7569 = torch.constant.int 128 + %7011 = torch.prim.ListConstruct %456, %int32_7565, %int2_7566, %int8_7567, %int32_7568, %int128_7569 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7012 = torch.aten.view %6832, %7011 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7012, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_7570 = torch.constant.int 128 + %7013 = torch.prim.ListConstruct %596, %int128_7570 : (!torch.int, !torch.int) -> !torch.list + %7014 = torch.aten.view %7012, %7013 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7014, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %7015 = torch.prim.ListConstruct %7009 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_7571 = torch.constant.bool false + %7016 = torch.aten.index_put %7014, %7015, %7010, %false_7571 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7016, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> %int32_7572 = torch.constant.int 32 - %int8_7573 = torch.constant.int 8 - %int128_7574 = torch.constant.int 128 - %6016 = torch.prim.ListConstruct %437, %int32_7570, %int2_7571, %int32_7572, %int8_7573, %int128_7574 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6017 = torch.aten.view %6006, %6016 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6017, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> + %int2_7573 = torch.constant.int 2 + %int8_7574 = torch.constant.int 8 %int32_7575 = torch.constant.int 32 - %6018 = torch.aten.mul.int %437, %int32_7575 : !torch.int, !torch.int -> !torch.int - 
%int2_7576 = torch.constant.int 2 - %int32_7577 = torch.constant.int 32 - %int8_7578 = torch.constant.int 8 - %int128_7579 = torch.constant.int 128 - %6019 = torch.prim.ListConstruct %6018, %int2_7576, %int32_7577, %int8_7578, %int128_7579 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6020 = torch.aten.view %6017, %6019 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %6020, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_7580 = torch.constant.int 0 - %6021 = torch.aten.index_select %6020, %int0_7580, %6015 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %6021, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_7581 = torch.constant.int 4 - %int2_7582 = torch.constant.int 2 - %int32_7583 = torch.constant.int 32 - %int8_7584 = torch.constant.int 8 - %int128_7585 = torch.constant.int 128 - %6022 = torch.prim.ListConstruct %int4_7581, %358, %int2_7582, %int32_7583, %int8_7584, %int128_7585 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6023 = torch.aten.view %6021, %6022 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6023, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_7586 = torch.constant.int 0 - %int0_7587 = torch.constant.int 0 - %int9223372036854775807_7588 = torch.constant.int 9223372036854775807 - %int1_7589 = torch.constant.int 1 - %6024 = torch.aten.slice.Tensor %6023, %int0_7586, %int0_7587, %int9223372036854775807_7588, %int1_7589 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6024, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> + %int128_7576 = torch.constant.int 128 + %7017 = torch.prim.ListConstruct %456, %int32_7572, %int2_7573, %int8_7574, %int32_7575, %int128_7576 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7018 = torch.aten.view %7016, %7017 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7018, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7577 = torch.constant.int 2097152 + %7019 = torch.prim.ListConstruct %456, %int2097152_7577 : (!torch.int, !torch.int) -> !torch.list + %7020 = torch.aten.view %7018, %7019 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %7020, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_7578 = torch.constant.int 32 + %int2_7579 = torch.constant.int 2 + %int8_7580 = torch.constant.int 8 + %int32_7581 = torch.constant.int 32 + %int128_7582 = torch.constant.int 128 + %7021 = torch.prim.ListConstruct %456, %int32_7578, %int2_7579, %int8_7580, %int32_7581, %int128_7582 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7022 = torch.aten.view %7020, %7021 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7022, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 
128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_7583 = torch.constant.int 128 + %7023 = torch.prim.ListConstruct %596, %int128_7583 : (!torch.int, !torch.int) -> !torch.list + %7024 = torch.aten.view %7022, %7023 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7024, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_7584 = torch.constant.none + %7025 = torch.aten.clone %427, %none_7584 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %7026 = torch.aten.detach %7025 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7027 = torch.aten.detach %7026 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7028 = torch.aten.detach %7027 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_7585 = torch.constant.int 1 + %int1_7586 = torch.constant.int 1 + %int1_7587 = torch.constant.int 1 + %7029 = torch.prim.ListConstruct %int1_7585, %int1_7586, %int1_7587 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7030 = torch.aten.view %7028, %7029 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_7588 = torch.constant.int 32 + %7031 = torch.aten.mul.Scalar %6989, %int32_7588 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int30_7589 = torch.constant.int 30 %int1_7590 = torch.constant.int 1 - %int0_7591 = torch.constant.int 0 - %int9223372036854775807_7592 = torch.constant.int 9223372036854775807 - %int1_7593 = torch.constant.int 1 - %6025 = torch.aten.slice.Tensor %6024, %int1_7590, %int0_7591, %int9223372036854775807_7592, %int1_7593 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6025, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_7594 = torch.constant.int 2 - %int0_7595 = torch.constant.int 0 - %6026 = torch.aten.select.int %6025, %int2_7594, %int0_7595 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6026, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_7596 = torch.constant.int 32 - %6027 = torch.aten.mul.int %358, %int32_7596 : !torch.int, !torch.int -> !torch.int - %int2_7597 = torch.constant.int 2 - %int0_7598 = torch.constant.int 0 - %int1_7599 = torch.constant.int 1 - %6028 = torch.aten.slice.Tensor %6026, %int2_7597, %int0_7598, %6027, %int1_7599 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6028, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_7600 = torch.constant.int 0 - %6029 = torch.aten.clone %6028, %int0_7600 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6029, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_7601 = torch.constant.int 1 - %6030 = torch.aten.size.int %6025, %int1_7601 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int + %7032 = torch.aten.add.Scalar %7031, %int30_7589, %int1_7590 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_7591 = torch.constant.int 2 + %7033 = torch.aten.mul.Scalar %7032, %int2_7591 : !torch.vtensor<[4,1,1],si64>, !torch.int -> 
!torch.vtensor<[4,1,1],si64> + %int1_7592 = torch.constant.int 1 + %7034 = torch.aten.add.Tensor %7033, %7030, %int1_7592 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_7593 = torch.constant.int 8 + %7035 = torch.aten.mul.Scalar %7034, %int8_7593 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_7594 = torch.constant.int 1 + %7036 = torch.aten.add.Tensor %7035, %6995, %int1_7594 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_7595 = torch.constant.int 32 + %7037 = torch.aten.mul.Scalar %7036, %int32_7595 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_7596 = torch.constant.int 1 + %7038 = torch.aten.add.Tensor %7037, %6992, %int1_7596 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_7597 = torch.constant.int 5 + %7039 = torch.prims.convert_element_type %6964, %int5_7597 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %7040 = torch.prim.ListConstruct %7038 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list<optional<vtensor>> + %false_7598 = torch.constant.bool false + %7041 = torch.aten.index_put %7024, %7040, %7039, %false_7598 : !torch.vtensor<[?,128],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7041, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_7599 = torch.constant.int 32 + %int2_7600 = torch.constant.int 2 + %int8_7601 = torch.constant.int 8 %int32_7602 = torch.constant.int 32 - %6031 = torch.aten.mul.int %6030, %int32_7602 : !torch.int, !torch.int -> !torch.int - %int4_7603 = torch.constant.int 4 - %int8_7604 = torch.constant.int 8 - %int128_7605 = torch.constant.int 128 - %6032 = torch.prim.ListConstruct %int4_7603, %6031, %int8_7604, %int128_7605 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> - %6033 = torch.aten._unsafe_view %6029, %6032 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6033, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_7606 = torch.constant.int 0 - %int0_7607 = torch.constant.int 0 - %int9223372036854775807_7608 = torch.constant.int 9223372036854775807 - %int1_7609 = torch.constant.int 1 - %6034 = torch.aten.slice.Tensor %6033, %int0_7606, %int0_7607, %int9223372036854775807_7608, %int1_7609 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6034, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_7610 = torch.constant.int 0 - %int0_7611 = torch.constant.int 0 - %int9223372036854775807_7612 = torch.constant.int 9223372036854775807 - %int1_7613 = torch.constant.int 1 - %6035 = torch.aten.slice.Tensor %6023, %int0_7610, %int0_7611, %int9223372036854775807_7612, %int1_7613 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6035, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_7614 = torch.constant.int 1 - %int0_7615 = torch.constant.int 0 - %int9223372036854775807_7616 = torch.constant.int 9223372036854775807 - %int1_7617 = torch.constant.int 1 - %6036 = 
torch.aten.slice.Tensor %6035, %int1_7614, %int0_7615, %int9223372036854775807_7616, %int1_7617 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6036, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_7618 = torch.constant.int 2 - %int1_7619 = torch.constant.int 1 - %6037 = torch.aten.select.int %6036, %int2_7618, %int1_7619 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6037, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_7620 = torch.constant.int 2 - %int0_7621 = torch.constant.int 0 - %int1_7622 = torch.constant.int 1 - %6038 = torch.aten.slice.Tensor %6037, %int2_7620, %int0_7621, %6027, %int1_7622 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6038, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_7623 = torch.constant.int 0 - %6039 = torch.aten.clone %6038, %int0_7623 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6039, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_7624 = torch.constant.int 1 - %6040 = torch.aten.size.int %6036, %int1_7624 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_7625 = torch.constant.int 32 - %6041 = torch.aten.mul.int %6040, %int32_7625 : !torch.int, !torch.int -> !torch.int + %int128_7603 = torch.constant.int 128 + %7042 = torch.prim.ListConstruct %456, %int32_7599, %int2_7600, %int8_7601, %int32_7602, %int128_7603 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7043 = torch.aten.view %7041, %7042 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7043, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7604 = torch.constant.int 2097152 + %7044 = torch.prim.ListConstruct %456, %int2097152_7604 : (!torch.int, !torch.int) -> !torch.list + %7045 = torch.aten.view %7043, %7044 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %7045, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_7605 = torch.constant.none + %7046 = torch.aten.clone %428, %none_7605 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %7047 = torch.aten.detach %7046 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7048 = torch.aten.detach %7047 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7049 = torch.aten.detach %7048 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_7606 = torch.constant.none + %7050 = torch.aten.clone %429, %none_7606 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %7051 = torch.aten.detach %7050 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7052 = torch.aten.detach %7051 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7053 = torch.aten.detach %7052 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_7607 = torch.constant.none + %7054 = torch.aten.clone %430, %none_7607 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %7055 = 
torch.aten.detach %7054 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7056 = torch.aten.detach %7055 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7057 = torch.aten.detach %7056 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_7608 = torch.constant.int 32 + %int2_7609 = torch.constant.int 2 + %int8_7610 = torch.constant.int 8 + %int32_7611 = torch.constant.int 32 + %int128_7612 = torch.constant.int 128 + %7058 = torch.prim.ListConstruct %456, %int32_7608, %int2_7609, %int8_7610, %int32_7611, %int128_7612 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7059 = torch.aten.view %7045, %7058 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7059, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %7060 = torch_c.to_builtin_tensor %7059 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %7061 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_7613 = tensor.cast %7061 : tensor<4x?xi64> to tensor + %7062 = torch_c.to_builtin_tensor %7049 : !torch.vtensor<[],si64> -> tensor + %7063 = torch_c.to_builtin_tensor %7053 : !torch.vtensor<[],si64> -> tensor + %7064 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%7060, %cast_7613, %7062, %7063) : (tensor, tensor, tensor, tensor) -> tensor + %cast_7614 = tensor.cast %7064 : tensor to tensor<4x?x8x32x128xf16> + %7065 = torch_c.from_builtin_tensor %cast_7614 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %7065, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %7066 = torch_c.to_builtin_tensor %7059 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %7067 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_7615 = tensor.cast %7067 : tensor<4x?xi64> to tensor + %7068 = torch_c.to_builtin_tensor %7049 : !torch.vtensor<[],si64> -> tensor + %7069 = torch_c.to_builtin_tensor %7057 : !torch.vtensor<[],si64> -> tensor + %7070 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%7066, %cast_7615, %7068, %7069) : (tensor, tensor, tensor, tensor) -> tensor + %cast_7616 = tensor.cast %7070 : tensor to tensor<4x?x8x32x128xf16> + %7071 = torch_c.from_builtin_tensor %cast_7616 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %7071, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_7617 = torch.constant.int 2 + %int3_7618 = torch.constant.int 3 + %7072 = torch.aten.transpose.int %7065, %int2_7617, %int3_7618 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7072, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_7619 = torch.constant.int 0 + %7073 = torch.aten.clone %7072, %int0_7619 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7073, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : 
!torch.vtensor<[4,?,32,8,128],f16> + %int4_7620 = torch.constant.int 4 + %int8_7621 = torch.constant.int 8 + %int128_7622 = torch.constant.int 128 + %7074 = torch.prim.ListConstruct %int4_7620, %457, %int8_7621, %int128_7622 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7075 = torch.aten._unsafe_view %7073, %7074 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7075, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_7623 = torch.constant.int 2 + %int3_7624 = torch.constant.int 3 + %7076 = torch.aten.transpose.int %7071, %int2_7623, %int3_7624 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7076, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_7625 = torch.constant.int 0 + %7077 = torch.aten.clone %7076, %int0_7625 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7077, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int4_7626 = torch.constant.int 4 %int8_7627 = torch.constant.int 8 %int128_7628 = torch.constant.int 128 - %6042 = torch.prim.ListConstruct %int4_7626, %6041, %int8_7627, %int128_7628 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6043 = torch.aten._unsafe_view %6039, %6042 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6043, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_7629 = torch.constant.int 0 - %int0_7630 = torch.constant.int 0 - %int9223372036854775807_7631 = torch.constant.int 9223372036854775807 - %int1_7632 = torch.constant.int 1 - %6044 = torch.aten.slice.Tensor %6043, %int0_7629, %int0_7630, %int9223372036854775807_7631, %int1_7632 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6044, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_7633 = torch.constant.int -2 - %6045 = torch.aten.unsqueeze %6034, %int-2_7633 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6045, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_7634 = torch.constant.int 1 - %6046 = torch.aten.size.int %6033, %int1_7634 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_7635 = torch.constant.int 4 - %int8_7636 = torch.constant.int 8 - %int4_7637 = torch.constant.int 4 + %7078 = torch.prim.ListConstruct %int4_7626, %457, %int8_7627, %int128_7628 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7079 = torch.aten._unsafe_view %7077, %7078 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7079, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_7629 = torch.constant.int -2 + %7080 = torch.aten.unsqueeze %7075, %int-2_7629 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %7080, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_7630 = torch.constant.int 4 + %int8_7631 = torch.constant.int 8 + %int4_7632 = 
torch.constant.int 4 + %int128_7633 = torch.constant.int 128 + %7081 = torch.prim.ListConstruct %int4_7630, %457, %int8_7631, %int4_7632, %int128_7633 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_7634 = torch.constant.bool false + %7082 = torch.aten.expand %7080, %7081, %false_7634 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7082, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_7635 = torch.constant.int 0 + %7083 = torch.aten.clone %7082, %int0_7635 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7083, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_7636 = torch.constant.int 4 + %int32_7637 = torch.constant.int 32 %int128_7638 = torch.constant.int 128 - %6047 = torch.prim.ListConstruct %int4_7635, %6046, %int8_7636, %int4_7637, %int128_7638 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7639 = torch.constant.bool false - %6048 = torch.aten.expand %6045, %6047, %false_7639 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6048, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_7640 = torch.constant.int 0 - %6049 = torch.aten.clone %6048, %int0_7640 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6049, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7641 = torch.constant.int 4 - %int32_7642 = torch.constant.int 32 + %7084 = torch.prim.ListConstruct %int4_7636, %457, %int32_7637, %int128_7638 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7085 = torch.aten._unsafe_view %7083, %7084 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7085, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_7639 = torch.constant.int -2 + %7086 = torch.aten.unsqueeze %7079, %int-2_7639 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %7086, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_7640 = torch.constant.int 4 + %int8_7641 = torch.constant.int 8 + %int4_7642 = torch.constant.int 4 %int128_7643 = torch.constant.int 128 - %6050 = torch.prim.ListConstruct %int4_7641, %6046, %int32_7642, %int128_7643 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6051 = torch.aten._unsafe_view %6049, %6050 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6051, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_7644 = torch.constant.int -2 - %6052 = torch.aten.unsqueeze %6044, %int-2_7644 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6052, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_7645 = torch.constant.int 1 - %6053 = torch.aten.size.int %6043, %int1_7645 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int + %7087 = torch.prim.ListConstruct %int4_7640, %457, %int8_7641, 
%int4_7642, %int128_7643 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_7644 = torch.constant.bool false + %7088 = torch.aten.expand %7086, %7087, %false_7644 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7088, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_7645 = torch.constant.int 0 + %7089 = torch.aten.clone %7088, %int0_7645 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7089, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int4_7646 = torch.constant.int 4 - %int8_7647 = torch.constant.int 8 - %int4_7648 = torch.constant.int 4 - %int128_7649 = torch.constant.int 128 - %6054 = torch.prim.ListConstruct %int4_7646, %6053, %int8_7647, %int4_7648, %int128_7649 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7650 = torch.constant.bool false - %6055 = torch.aten.expand %6052, %6054, %false_7650 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6055, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_7651 = torch.constant.int 0 - %6056 = torch.aten.clone %6055, %int0_7651 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6056, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7652 = torch.constant.int 4 - %int32_7653 = torch.constant.int 32 - %int128_7654 = torch.constant.int 128 - %6057 = torch.prim.ListConstruct %int4_7652, %6053, %int32_7653, %int128_7654 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6058 = torch.aten._unsafe_view %6056, %6057 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6058, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_7655 = torch.constant.int 1 - %int2_7656 = torch.constant.int 2 - %6059 = torch.aten.transpose.int %5939, %int1_7655, %int2_7656 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_7657 = torch.constant.int 1 - %int2_7658 = torch.constant.int 2 - %6060 = torch.aten.transpose.int %6051, %int1_7657, %int2_7658 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6060, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_7659 = torch.constant.int 1 - %int2_7660 = torch.constant.int 2 - %6061 = torch.aten.transpose.int %6058, %int1_7659, %int2_7660 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6061, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_7661 = torch.constant.float 0.000000e+00 - %false_7662 = torch.constant.bool false - %none_7663 = torch.constant.none - %6062:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6059, %6060, %6061, %float0.000000e00_7661, %false_7662, %368, %none_7663) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, 
!torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_7664 = torch.constant.int 1 - %int2_7665 = torch.constant.int 2 - %6063 = torch.aten.transpose.int %6062#0, %int1_7664, %int2_7665 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int32_7647 = torch.constant.int 32 + %int128_7648 = torch.constant.int 128 + %7090 = torch.prim.ListConstruct %int4_7646, %457, %int32_7647, %int128_7648 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7091 = torch.aten._unsafe_view %7089, %7090 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7091, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_7649 = torch.constant.int 1 + %int2_7650 = torch.constant.int 2 + %7092 = torch.aten.transpose.int %6974, %int1_7649, %int2_7650 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_7651 = torch.constant.int 1 + %int2_7652 = torch.constant.int 2 + %7093 = torch.aten.transpose.int %7085, %int1_7651, %int2_7652 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7093, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_7653 = torch.constant.int 1 + %int2_7654 = torch.constant.int 2 + %7094 = torch.aten.transpose.int %7091, %int1_7653, %int2_7654 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7094, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_7655 = torch.constant.float 0.000000e+00 + %false_7656 = torch.constant.bool false + %none_7657 = torch.constant.none + %7095:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%7092, %7093, %7094, %float0.000000e00_7655, %false_7656, %470, %none_7657) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_7658 = torch.constant.int 1 + %int2_7659 = torch.constant.int 2 + %7096 = torch.aten.transpose.int %7095#0, %int1_7658, %int2_7659 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_7660 = torch.constant.int 4 + %int1_7661 = torch.constant.int 1 + %int4096_7662 = torch.constant.int 4096 + %7097 = torch.prim.ListConstruct %int4_7660, %int1_7661, %int4096_7662 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7098 = torch.aten.view %7096, %7097 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_7663 = torch.constant.int -2 + %int-1_7664 = torch.constant.int -1 + %7099 = torch.aten.transpose.int %431, %int-2_7663, %int-1_7664 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_7665 = torch.constant.int 5 + %7100 = torch.prims.convert_element_type %7099, %int5_7665 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> %int4_7666 = torch.constant.int 4 - %int1_7667 = torch.constant.int 1 - %int4096_7668 = torch.constant.int 4096 - %6064 = torch.prim.ListConstruct %int4_7666, %int1_7667, %int4096_7668 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6065 = torch.aten.view %6063, 
%6064 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_7669 = torch.constant.int -2 - %int-1_7670 = torch.constant.int -1 - %6066 = torch.aten.transpose.int %304, %int-2_7669, %int-1_7670 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7671 = torch.constant.int 4 - %int4096_7672 = torch.constant.int 4096 - %6067 = torch.prim.ListConstruct %int4_7671, %int4096_7672 : (!torch.int, !torch.int) -> !torch.list - %6068 = torch.aten.view %6065, %6067 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6069 = torch.aten.mm %6068, %6066 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_7673 = torch.constant.int 4 - %int1_7674 = torch.constant.int 1 - %int4096_7675 = torch.constant.int 4096 - %6070 = torch.prim.ListConstruct %int4_7673, %int1_7674, %int4096_7675 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6071 = torch.aten.view %6069, %6070 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_7676 = torch.constant.int 1 - %6072 = torch.aten.add.Tensor %5899, %6071, %int1_7676 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_7677 = torch.constant.int 6 - %6073 = torch.prims.convert_element_type %6072, %int6_7677 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_7678 = torch.constant.int 2 - %6074 = torch.aten.pow.Tensor_Scalar %6073, %int2_7678 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_7679 = torch.constant.int -1 - %6075 = torch.prim.ListConstruct %int-1_7679 : (!torch.int) -> !torch.list - %true_7680 = torch.constant.bool true - %none_7681 = torch.constant.none - %6076 = torch.aten.mean.dim %6074, %6075, %true_7680, %none_7681 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_7682 = torch.constant.float 9.9999997473787516E-6 - %int1_7683 = torch.constant.int 1 - %6077 = torch.aten.add.Scalar %6076, %float9.999990e-06_7682, %int1_7683 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %6078 = torch.aten.rsqrt %6077 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %6079 = torch.aten.mul.Tensor %6073, %6078 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_7684 = torch.constant.int 5 - %6080 = torch.prims.convert_element_type %6079, %int5_7684 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %6081 = torch.aten.mul.Tensor %305, %6080 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_7685 = torch.constant.int 5 - %6082 = torch.prims.convert_element_type %6081, %int5_7685 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_7686 = torch.constant.int -2 - %int-1_7687 = torch.constant.int -1 - %6083 = torch.aten.transpose.int %306, %int-2_7686, %int-1_7687 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_7688 = torch.constant.int 4 - %int4096_7689 = torch.constant.int 4096 - %6084 = torch.prim.ListConstruct %int4_7688, %int4096_7689 : (!torch.int, !torch.int) -> !torch.list - %6085 = torch.aten.view %6082, %6084 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> 
- %6086 = torch.aten.mm %6085, %6083 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_7690 = torch.constant.int 4 - %int1_7691 = torch.constant.int 1 - %int14336_7692 = torch.constant.int 14336 - %6087 = torch.prim.ListConstruct %int4_7690, %int1_7691, %int14336_7692 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6088 = torch.aten.view %6086, %6087 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %6089 = torch.aten.silu %6088 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_7693 = torch.constant.int -2 - %int-1_7694 = torch.constant.int -1 - %6090 = torch.aten.transpose.int %307, %int-2_7693, %int-1_7694 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_7695 = torch.constant.int 4 - %int4096_7696 = torch.constant.int 4096 - %6091 = torch.prim.ListConstruct %int4_7695, %int4096_7696 : (!torch.int, !torch.int) -> !torch.list - %6092 = torch.aten.view %6082, %6091 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6093 = torch.aten.mm %6092, %6090 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_7697 = torch.constant.int 4 - %int1_7698 = torch.constant.int 1 - %int14336_7699 = torch.constant.int 14336 - %6094 = torch.prim.ListConstruct %int4_7697, %int1_7698, %int14336_7699 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6095 = torch.aten.view %6093, %6094 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %6096 = torch.aten.mul.Tensor %6089, %6095 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_7700 = torch.constant.int -2 - %int-1_7701 = torch.constant.int -1 - %6097 = torch.aten.transpose.int %308, %int-2_7700, %int-1_7701 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4096_7667 = torch.constant.int 4096 + %7101 = torch.prim.ListConstruct %int4_7666, %int4096_7667 : (!torch.int, !torch.int) -> !torch.list + %7102 = torch.aten.view %7098, %7101 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %7103 = torch.aten.mm %7102, %7100 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_7668 = torch.constant.int 4 + %int1_7669 = torch.constant.int 1 + %int4096_7670 = torch.constant.int 4096 + %7104 = torch.prim.ListConstruct %int4_7668, %int1_7669, %int4096_7670 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7105 = torch.aten.view %7103, %7104 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_7671 = torch.constant.int 1 + %7106 = torch.aten.add.Tensor %6927, %7105, %int1_7671 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_7672 = torch.constant.int 6 + %7107 = torch.prims.convert_element_type %7106, %int6_7672 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_7673 = torch.constant.int 2 + %7108 = torch.aten.pow.Tensor_Scalar %7107, %int2_7673 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_7674 = torch.constant.int -1 + %7109 = torch.prim.ListConstruct %int-1_7674 : (!torch.int) -> !torch.list + %true_7675 = torch.constant.bool true + %none_7676 = torch.constant.none + %7110 = torch.aten.mean.dim %7108, %7109, 
%true_7675, %none_7676 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_7677 = torch.constant.float 9.9999997473787516E-6 + %int1_7678 = torch.constant.int 1 + %7111 = torch.aten.add.Scalar %7110, %float9.999990e-06_7677, %int1_7678 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %7112 = torch.aten.rsqrt %7111 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %7113 = torch.aten.mul.Tensor %7107, %7112 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_7679 = torch.constant.int 5 + %7114 = torch.prims.convert_element_type %7113, %int5_7679 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %7115 = torch.aten.mul.Tensor %432, %7114 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_7680 = torch.constant.int 5 + %7116 = torch.prims.convert_element_type %7115, %int5_7680 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_7681 = torch.constant.int -2 + %int-1_7682 = torch.constant.int -1 + %7117 = torch.aten.transpose.int %433, %int-2_7681, %int-1_7682 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_7683 = torch.constant.int 5 + %7118 = torch.prims.convert_element_type %7117, %int5_7683 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_7684 = torch.constant.int 4 + %int4096_7685 = torch.constant.int 4096 + %7119 = torch.prim.ListConstruct %int4_7684, %int4096_7685 : (!torch.int, !torch.int) -> !torch.list + %7120 = torch.aten.view %7116, %7119 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %7121 = torch.aten.mm %7120, %7118 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_7686 = torch.constant.int 4 + %int1_7687 = torch.constant.int 1 + %int14336_7688 = torch.constant.int 14336 + %7122 = torch.prim.ListConstruct %int4_7686, %int1_7687, %int14336_7688 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7123 = torch.aten.view %7121, %7122 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %7124 = torch.aten.silu %7123 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_7689 = torch.constant.int -2 + %int-1_7690 = torch.constant.int -1 + %7125 = torch.aten.transpose.int %434, %int-2_7689, %int-1_7690 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_7691 = torch.constant.int 5 + %7126 = torch.prims.convert_element_type %7125, %int5_7691 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_7692 = torch.constant.int 4 + %int4096_7693 = torch.constant.int 4096 + %7127 = torch.prim.ListConstruct %int4_7692, %int4096_7693 : (!torch.int, !torch.int) -> !torch.list + %7128 = torch.aten.view %7116, %7127 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %7129 = torch.aten.mm %7128, %7126 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_7694 = torch.constant.int 4 + %int1_7695 = torch.constant.int 1 + %int14336_7696 = torch.constant.int 14336 + %7130 = torch.prim.ListConstruct %int4_7694, %int1_7695, %int14336_7696 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7131 = torch.aten.view 
%7129, %7130 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %7132 = torch.aten.mul.Tensor %7124, %7131 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_7697 = torch.constant.int -2 + %int-1_7698 = torch.constant.int -1 + %7133 = torch.aten.transpose.int %435, %int-2_7697, %int-1_7698 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_7699 = torch.constant.int 5 + %7134 = torch.prims.convert_element_type %7133, %int5_7699 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_7700 = torch.constant.int 4 + %int14336_7701 = torch.constant.int 14336 + %7135 = torch.prim.ListConstruct %int4_7700, %int14336_7701 : (!torch.int, !torch.int) -> !torch.list + %7136 = torch.aten.view %7132, %7135 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %7137 = torch.aten.mm %7136, %7134 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> %int4_7702 = torch.constant.int 4 - %int14336_7703 = torch.constant.int 14336 - %6098 = torch.prim.ListConstruct %int4_7702, %int14336_7703 : (!torch.int, !torch.int) -> !torch.list - %6099 = torch.aten.view %6096, %6098 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %6100 = torch.aten.mm %6099, %6097 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_7704 = torch.constant.int 4 + %int1_7703 = torch.constant.int 1 + %int4096_7704 = torch.constant.int 4096 + %7138 = torch.prim.ListConstruct %int4_7702, %int1_7703, %int4096_7704 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7139 = torch.aten.view %7137, %7138 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> %int1_7705 = torch.constant.int 1 - %int4096_7706 = torch.constant.int 4096 - %6101 = torch.prim.ListConstruct %int4_7704, %int1_7705, %int4096_7706 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6102 = torch.aten.view %6100, %6101 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_7707 = torch.constant.int 1 - %6103 = torch.aten.add.Tensor %6072, %6102, %int1_7707 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_7708 = torch.constant.int 6 - %6104 = torch.prims.convert_element_type %6103, %int6_7708 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_7709 = torch.constant.int 2 - %6105 = torch.aten.pow.Tensor_Scalar %6104, %int2_7709 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_7710 = torch.constant.int -1 - %6106 = torch.prim.ListConstruct %int-1_7710 : (!torch.int) -> !torch.list - %true_7711 = torch.constant.bool true - %none_7712 = torch.constant.none - %6107 = torch.aten.mean.dim %6105, %6106, %true_7711, %none_7712 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_7713 = torch.constant.float 9.9999997473787516E-6 - %int1_7714 = torch.constant.int 1 - %6108 = torch.aten.add.Scalar %6107, %float9.999990e-06_7713, %int1_7714 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %6109 = torch.aten.rsqrt %6108 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %6110 = torch.aten.mul.Tensor %6104, %6109 : 
!torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_7715 = torch.constant.int 5 - %6111 = torch.prims.convert_element_type %6110, %int5_7715 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %6112 = torch.aten.mul.Tensor %309, %6111 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_7716 = torch.constant.int 5 - %6113 = torch.prims.convert_element_type %6112, %int5_7716 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_7717 = torch.constant.int -2 - %int-1_7718 = torch.constant.int -1 - %6114 = torch.aten.transpose.int %310, %int-2_7717, %int-1_7718 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7719 = torch.constant.int 4 - %int4096_7720 = torch.constant.int 4096 - %6115 = torch.prim.ListConstruct %int4_7719, %int4096_7720 : (!torch.int, !torch.int) -> !torch.list - %6116 = torch.aten.view %6113, %6115 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6117 = torch.aten.mm %6116, %6114 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_7721 = torch.constant.int 4 - %int1_7722 = torch.constant.int 1 - %int4096_7723 = torch.constant.int 4096 - %6118 = torch.prim.ListConstruct %int4_7721, %int1_7722, %int4096_7723 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6119 = torch.aten.view %6117, %6118 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_7724 = torch.constant.int -2 - %int-1_7725 = torch.constant.int -1 - %6120 = torch.aten.transpose.int %311, %int-2_7724, %int-1_7725 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %7140 = torch.aten.add.Tensor %7106, %7139, %int1_7705 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_7706 = torch.constant.int 6 + %7141 = torch.prims.convert_element_type %7140, %int6_7706 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_7707 = torch.constant.int 2 + %7142 = torch.aten.pow.Tensor_Scalar %7141, %int2_7707 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_7708 = torch.constant.int -1 + %7143 = torch.prim.ListConstruct %int-1_7708 : (!torch.int) -> !torch.list + %true_7709 = torch.constant.bool true + %none_7710 = torch.constant.none + %7144 = torch.aten.mean.dim %7142, %7143, %true_7709, %none_7710 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_7711 = torch.constant.float 9.9999997473787516E-6 + %int1_7712 = torch.constant.int 1 + %7145 = torch.aten.add.Scalar %7144, %float9.999990e-06_7711, %int1_7712 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %7146 = torch.aten.rsqrt %7145 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %7147 = torch.aten.mul.Tensor %7141, %7146 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_7713 = torch.constant.int 5 + %7148 = torch.prims.convert_element_type %7147, %int5_7713 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %7149 = torch.aten.mul.Tensor %436, %7148 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_7714 = 
torch.constant.int 5 + %7150 = torch.prims.convert_element_type %7149, %int5_7714 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_7715 = torch.constant.int -2 + %int-1_7716 = torch.constant.int -1 + %7151 = torch.aten.transpose.int %437, %int-2_7715, %int-1_7716 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_7717 = torch.constant.int 5 + %7152 = torch.prims.convert_element_type %7151, %int5_7717 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_7718 = torch.constant.int 4 + %int4096_7719 = torch.constant.int 4096 + %7153 = torch.prim.ListConstruct %int4_7718, %int4096_7719 : (!torch.int, !torch.int) -> !torch.list + %7154 = torch.aten.view %7150, %7153 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %7155 = torch.aten.mm %7154, %7152 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_7720 = torch.constant.int 4 + %int1_7721 = torch.constant.int 1 + %int4096_7722 = torch.constant.int 4096 + %7156 = torch.prim.ListConstruct %int4_7720, %int1_7721, %int4096_7722 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7157 = torch.aten.view %7155, %7156 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_7723 = torch.constant.int -2 + %int-1_7724 = torch.constant.int -1 + %7158 = torch.aten.transpose.int %438, %int-2_7723, %int-1_7724 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_7725 = torch.constant.int 5 + %7159 = torch.prims.convert_element_type %7158, %int5_7725 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> %int4_7726 = torch.constant.int 4 %int4096_7727 = torch.constant.int 4096 - %6121 = torch.prim.ListConstruct %int4_7726, %int4096_7727 : (!torch.int, !torch.int) -> !torch.list - %6122 = torch.aten.view %6113, %6121 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6123 = torch.aten.mm %6122, %6120 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %7160 = torch.prim.ListConstruct %int4_7726, %int4096_7727 : (!torch.int, !torch.int) -> !torch.list + %7161 = torch.aten.view %7150, %7160 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %7162 = torch.aten.mm %7161, %7159 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> %int4_7728 = torch.constant.int 4 %int1_7729 = torch.constant.int 1 %int1024_7730 = torch.constant.int 1024 - %6124 = torch.prim.ListConstruct %int4_7728, %int1_7729, %int1024_7730 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6125 = torch.aten.view %6123, %6124 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %7163 = torch.prim.ListConstruct %int4_7728, %int1_7729, %int1024_7730 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7164 = torch.aten.view %7162, %7163 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> %int-2_7731 = torch.constant.int -2 %int-1_7732 = torch.constant.int -1 - %6126 = torch.aten.transpose.int %312, %int-2_7731, %int-1_7732 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_7733 = torch.constant.int 4 - %int4096_7734 = torch.constant.int 4096 - %6127 = torch.prim.ListConstruct %int4_7733, %int4096_7734 : (!torch.int, !torch.int) -> 
!torch.list - %6128 = torch.aten.view %6113, %6127 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6129 = torch.aten.mm %6128, %6126 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_7735 = torch.constant.int 4 - %int1_7736 = torch.constant.int 1 - %int1024_7737 = torch.constant.int 1024 - %6130 = torch.prim.ListConstruct %int4_7735, %int1_7736, %int1024_7737 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6131 = torch.aten.view %6129, %6130 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_7738 = torch.constant.int 4 - %int1_7739 = torch.constant.int 1 - %int32_7740 = torch.constant.int 32 - %int128_7741 = torch.constant.int 128 - %6132 = torch.prim.ListConstruct %int4_7738, %int1_7739, %int32_7740, %int128_7741 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6133 = torch.aten.view %6119, %6132 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_7742 = torch.constant.int 4 - %int1_7743 = torch.constant.int 1 - %int8_7744 = torch.constant.int 8 - %int128_7745 = torch.constant.int 128 - %6134 = torch.prim.ListConstruct %int4_7742, %int1_7743, %int8_7744, %int128_7745 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6135 = torch.aten.view %6125, %6134 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_7746 = torch.constant.int 4 - %int1_7747 = torch.constant.int 1 - %int8_7748 = torch.constant.int 8 - %int128_7749 = torch.constant.int 128 - %6136 = torch.prim.ListConstruct %int4_7746, %int1_7747, %int8_7748, %int128_7749 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6137 = torch.aten.view %6131, %6136 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_7750 = torch.constant.int 6 - %6138 = torch.prims.convert_element_type %6133, %int6_7750 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %6139 = torch_c.to_builtin_tensor %6138 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %6140 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %6141 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%6139, %6140) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %6142 = torch_c.from_builtin_tensor %6141 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_7751 = torch.constant.int 5 - %6143 = torch.prims.convert_element_type %6142, %int5_7751 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_7752 = torch.constant.int 6 - %6144 = torch.prims.convert_element_type %6135, %int6_7752 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %6145 = torch_c.to_builtin_tensor %6144 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %6146 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %6147 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%6145, %6146) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %6148 = torch_c.from_builtin_tensor %6147 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_7753 = torch.constant.int 5 - %6149 = torch.prims.convert_element_type %6148, %int5_7753 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_7754 = torch.constant.int 32 - %6150 = 
torch.aten.floor_divide.Scalar %arg2, %int32_7754 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_7755 = torch.constant.int 1 - %6151 = torch.aten.unsqueeze %6150, %int1_7755 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %7165 = torch.aten.transpose.int %439, %int-2_7731, %int-1_7732 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int5_7733 = torch.constant.int 5 + %7166 = torch.prims.convert_element_type %7165, %int5_7733 : !torch.vtensor<[4096,1024],f16>, !torch.int -> !torch.vtensor<[4096,1024],f16> + %int4_7734 = torch.constant.int 4 + %int4096_7735 = torch.constant.int 4096 + %7167 = torch.prim.ListConstruct %int4_7734, %int4096_7735 : (!torch.int, !torch.int) -> !torch.list + %7168 = torch.aten.view %7150, %7167 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %7169 = torch.aten.mm %7168, %7166 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> + %int4_7736 = torch.constant.int 4 + %int1_7737 = torch.constant.int 1 + %int1024_7738 = torch.constant.int 1024 + %7170 = torch.prim.ListConstruct %int4_7736, %int1_7737, %int1024_7738 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7171 = torch.aten.view %7169, %7170 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> + %int4_7739 = torch.constant.int 4 + %int1_7740 = torch.constant.int 1 + %int32_7741 = torch.constant.int 32 + %int128_7742 = torch.constant.int 128 + %7172 = torch.prim.ListConstruct %int4_7739, %int1_7740, %int32_7741, %int128_7742 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7173 = torch.aten.view %7157, %7172 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> + %int4_7743 = torch.constant.int 4 + %int1_7744 = torch.constant.int 1 + %int8_7745 = torch.constant.int 8 + %int128_7746 = torch.constant.int 128 + %7174 = torch.prim.ListConstruct %int4_7743, %int1_7744, %int8_7745, %int128_7746 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7175 = torch.aten.view %7164, %7174 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int4_7747 = torch.constant.int 4 + %int1_7748 = torch.constant.int 1 + %int8_7749 = torch.constant.int 8 + %int128_7750 = torch.constant.int 128 + %7176 = torch.prim.ListConstruct %int4_7747, %int1_7748, %int8_7749, %int128_7750 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7177 = torch.aten.view %7171, %7176 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> + %int1_7751 = torch.constant.int 1 + %int2_7752 = torch.constant.int 2 + %7178 = torch.aten.transpose.int %7173, %int1_7751, %int2_7752 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %7179 = torch.aten.mul.Tensor %7178, %527 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int3_7753 = torch.constant.int 3 + %int0_7754 = torch.constant.int 0 + %int64_7755 = torch.constant.int 64 %int1_7756 = torch.constant.int 1 - %false_7757 = torch.constant.bool false - %6152 = torch.aten.gather %arg3, %int1_7756, %6151, %false_7757 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_7758 = torch.constant.int 32 - %6153 = torch.aten.remainder.Scalar %arg2, %int32_7758 : !torch.vtensor<[4],si64>, !torch.int -> 
!torch.vtensor<[4],si64> - %int1_7759 = torch.constant.int 1 - %6154 = torch.aten.unsqueeze %6153, %int1_7759 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_7760 = torch.constant.none - %6155 = torch.aten.clone %313, %none_7760 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_7761 = torch.constant.int 0 - %6156 = torch.aten.unsqueeze %6155, %int0_7761 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_7762 = torch.constant.int 4 + %7180 = torch.aten.slice.Tensor %7178, %int3_7753, %int0_7754, %int64_7755, %int1_7756 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %int3_7757 = torch.constant.int 3 + %int64_7758 = torch.constant.int 64 + %int9223372036854775807_7759 = torch.constant.int 9223372036854775807 + %int1_7760 = torch.constant.int 1 + %7181 = torch.aten.slice.Tensor %7178, %int3_7757, %int64_7758, %int9223372036854775807_7759, %int1_7760 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,64],f16> + %7182 = torch.aten.neg %7181 : !torch.vtensor<[4,32,1,64],f16> -> !torch.vtensor<[4,32,1,64],f16> + %7183 = torch.prim.ListConstruct %7182, %7180 : (!torch.vtensor<[4,32,1,64],f16>, !torch.vtensor<[4,32,1,64],f16>) -> !torch.list + %int-1_7761 = torch.constant.int -1 + %7184 = torch.aten.cat %7183, %int-1_7761 : !torch.list, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %7185 = torch.aten.mul.Tensor %7184, %531 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,32,1,128],f16> + %int1_7762 = torch.constant.int 1 + %7186 = torch.aten.add.Tensor %7179, %7185, %int1_7762 : !torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1,128],f16>, !torch.int -> !torch.vtensor<[4,32,1,128],f16> %int1_7763 = torch.constant.int 1 - %6157 = torch.prim.ListConstruct %int4_7762, %int1_7763 : (!torch.int, !torch.int) -> !torch.list - %int1_7764 = torch.constant.int 1 + %int2_7764 = torch.constant.int 2 + %7187 = torch.aten.transpose.int %7186, %int1_7763, %int2_7764 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> %int1_7765 = torch.constant.int 1 - %6158 = torch.prim.ListConstruct %int1_7764, %int1_7765 : (!torch.int, !torch.int) -> !torch.list - %int4_7766 = torch.constant.int 4 - %int0_7767 = torch.constant.int 0 - %cpu_7768 = torch.constant.device "cpu" - %false_7769 = torch.constant.bool false - %6159 = torch.aten.empty_strided %6157, %6158, %int4_7766, %int0_7767, %cpu_7768, %false_7769 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int28 = torch.constant.int 28 - %6160 = torch.aten.fill.Scalar %6159, %int28 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_7770 = torch.constant.int 4 - %int1_7771 = torch.constant.int 1 - %6161 = torch.prim.ListConstruct %int4_7770, %int1_7771 : (!torch.int, !torch.int) -> !torch.list - %6162 = torch.aten.repeat %6156, %6161 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_7772 = torch.constant.int 32 - %6163 = torch.aten.mul.Scalar %6152, %int32_7772 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7773 = torch.constant.int 1 - %6164 = torch.aten.add.Tensor %6163, %6160, %int1_7773 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_7774 = torch.constant.int 2 - %6165 = 
torch.aten.mul.Scalar %6164, %int2_7774 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7775 = torch.constant.int 1 - %6166 = torch.aten.add.Tensor %6165, %6162, %int1_7775 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_7776 = torch.constant.int 32 - %6167 = torch.aten.mul.Scalar %6166, %int32_7776 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int2_7766 = torch.constant.int 2 + %7188 = torch.aten.transpose.int %7175, %int1_7765, %int2_7766 : !torch.vtensor<[4,1,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %7189 = torch.aten.mul.Tensor %7188, %527 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int3_7767 = torch.constant.int 3 + %int0_7768 = torch.constant.int 0 + %int64_7769 = torch.constant.int 64 + %int1_7770 = torch.constant.int 1 + %7190 = torch.aten.slice.Tensor %7188, %int3_7767, %int0_7768, %int64_7769, %int1_7770 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %int3_7771 = torch.constant.int 3 + %int64_7772 = torch.constant.int 64 + %int9223372036854775807_7773 = torch.constant.int 9223372036854775807 + %int1_7774 = torch.constant.int 1 + %7191 = torch.aten.slice.Tensor %7188, %int3_7771, %int64_7772, %int9223372036854775807_7773, %int1_7774 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,8,1,64],f16> + %7192 = torch.aten.neg %7191 : !torch.vtensor<[4,8,1,64],f16> -> !torch.vtensor<[4,8,1,64],f16> + %7193 = torch.prim.ListConstruct %7192, %7190 : (!torch.vtensor<[4,8,1,64],f16>, !torch.vtensor<[4,8,1,64],f16>) -> !torch.list + %int-1_7775 = torch.constant.int -1 + %7194 = torch.aten.cat %7193, %int-1_7775 : !torch.list, !torch.int -> !torch.vtensor<[4,8,1,128],f16> + %7195 = torch.aten.mul.Tensor %7194, %531 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,1,1,128],f16> -> !torch.vtensor<[4,8,1,128],f16> + %int1_7776 = torch.constant.int 1 + %7196 = torch.aten.add.Tensor %7189, %7195, %int1_7776 : !torch.vtensor<[4,8,1,128],f16>, !torch.vtensor<[4,8,1,128],f16>, !torch.int -> !torch.vtensor<[4,8,1,128],f16> %int1_7777 = torch.constant.int 1 - %6168 = torch.aten.add.Tensor %6167, %6154, %int1_7777 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_7778 = torch.constant.int 32 - %int2_7779 = torch.constant.int 2 - %int32_7780 = torch.constant.int 32 - %int8_7781 = torch.constant.int 8 - %int128_7782 = torch.constant.int 128 - %6169 = torch.prim.ListConstruct %437, %int32_7778, %int2_7779, %int32_7780, %int8_7781, %int128_7782 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6170 = torch.aten.view %6006, %6169 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6170, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7783 = torch.constant.int 32 - %6171 = torch.aten.mul.int %437, %int32_7783 : !torch.int, !torch.int -> !torch.int - %int2_7784 = torch.constant.int 2 - %6172 = torch.aten.mul.int %6171, %int2_7784 : !torch.int, !torch.int -> !torch.int - %int32_7785 = torch.constant.int 32 - %6173 = torch.aten.mul.int %6172, %int32_7785 : !torch.int, !torch.int -> !torch.int - %int8_7786 = torch.constant.int 8 - %int128_7787 = torch.constant.int 
128 - %6174 = torch.prim.ListConstruct %6173, %int8_7786, %int128_7787 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6175 = torch.aten.view %6170, %6174 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6175, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %6176 = torch.prim.ListConstruct %6168 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_7788 = torch.constant.bool false - %6177 = torch.aten.index_put %6175, %6176, %6149, %false_7788 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6177, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_7789 = torch.constant.int 32 - %int2_7790 = torch.constant.int 2 - %int32_7791 = torch.constant.int 32 - %int8_7792 = torch.constant.int 8 - %int128_7793 = torch.constant.int 128 - %6178 = torch.prim.ListConstruct %437, %int32_7789, %int2_7790, %int32_7791, %int8_7792, %int128_7793 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6179 = torch.aten.view %6177, %6178 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6179, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_7794 = torch.constant.int 2097152 - %6180 = torch.prim.ListConstruct %437, %int2097152_7794 : (!torch.int, !torch.int) -> !torch.list - %6181 = torch.aten.view %6179, %6180 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6181, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_7795 = torch.constant.int 32 - %int2_7796 = torch.constant.int 2 - %int32_7797 = torch.constant.int 32 - %int8_7798 = torch.constant.int 8 - %int128_7799 = torch.constant.int 128 - %6182 = torch.prim.ListConstruct %437, %int32_7795, %int2_7796, %int32_7797, %int8_7798, %int128_7799 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6183 = torch.aten.view %6181, %6182 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6183, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_7800 = torch.constant.int 8 - %int128_7801 = torch.constant.int 128 - %6184 = torch.prim.ListConstruct %6173, %int8_7800, %int128_7801 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6185 = torch.aten.view %6183, %6184 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6185, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> + %int2_7778 = torch.constant.int 2 + %7197 = torch.aten.transpose.int %7196, %int1_7777, %int2_7778 : !torch.vtensor<[4,8,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_7779 = torch.constant.int 32 + %7198 = torch.aten.floor_divide.Scalar %arg2, %int32_7779 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int1_7780 = torch.constant.int 1 + %7199 = torch.aten.unsqueeze %7198, %int1_7780 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + %int1_7781 = torch.constant.int 1 + %false_7782 = torch.constant.bool false + %7200 = torch.aten.gather %arg3, %int1_7781, %7199, 
%false_7782 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> + %int4_7783 = torch.constant.int 4 + %int1_7784 = torch.constant.int 1 + %int1_7785 = torch.constant.int 1 + %7201 = torch.prim.ListConstruct %int4_7783, %int1_7784, %int1_7785 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7202 = torch.aten.view %7200, %7201 : !torch.vtensor<[4,1],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int32_7786 = torch.constant.int 32 + %7203 = torch.aten.remainder.Scalar %arg2, %int32_7786 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %int4_7787 = torch.constant.int 4 + %int1_7788 = torch.constant.int 1 + %int1_7789 = torch.constant.int 1 + %7204 = torch.prim.ListConstruct %int4_7787, %int1_7788, %int1_7789 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7205 = torch.aten.view %7203, %7204 : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[4,1,1],si64> + %int8_7790 = torch.constant.int 8 + %none_7791 = torch.constant.none + %none_7792 = torch.constant.none + %cpu_7793 = torch.constant.device "cpu" + %false_7794 = torch.constant.bool false + %7206 = torch.aten.arange %int8_7790, %none_7791, %none_7792, %cpu_7793, %false_7794 : !torch.int, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[8],si64> + %int1_7795 = torch.constant.int 1 + %int1_7796 = torch.constant.int 1 + %int8_7797 = torch.constant.int 8 + %7207 = torch.prim.ListConstruct %int1_7795, %int1_7796, %int8_7797 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7208 = torch.aten.view %7206, %7207 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,1,8],si64> + %none_7798 = torch.constant.none + %7209 = torch.aten.clone %440, %none_7798 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %7210 = torch.aten.detach %7209 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7211 = torch.aten.detach %7210 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7212 = torch.aten.detach %7211 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_7799 = torch.constant.int 1 + %int1_7800 = torch.constant.int 1 + %int1_7801 = torch.constant.int 1 + %7213 = torch.prim.ListConstruct %int1_7799, %int1_7800, %int1_7801 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7214 = torch.aten.view %7212, %7213 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> %int32_7802 = torch.constant.int 32 - %6186 = torch.aten.floor_divide.Scalar %arg2, %int32_7802 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %7215 = torch.aten.mul.Scalar %7202, %int32_7802 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int31 = torch.constant.int 31 %int1_7803 = torch.constant.int 1 - %6187 = torch.aten.unsqueeze %6186, %int1_7803 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7804 = torch.constant.int 1 - %false_7805 = torch.constant.bool false - %6188 = torch.aten.gather %arg3, %int1_7804, %6187, %false_7805 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_7806 = torch.constant.int 32 - %6189 = torch.aten.remainder.Scalar %arg2, %int32_7806 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + %7216 = torch.aten.add.Scalar %7215, %int31, %int1_7803 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_7804 = torch.constant.int 2 + %7217 = torch.aten.mul.Scalar %7216, 
%int2_7804 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_7805 = torch.constant.int 1 + %7218 = torch.aten.add.Tensor %7217, %7214, %int1_7805 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_7806 = torch.constant.int 8 + %7219 = torch.aten.mul.Scalar %7218, %int8_7806 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> %int1_7807 = torch.constant.int 1 - %6190 = torch.aten.unsqueeze %6189, %int1_7807 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_7808 = torch.constant.none - %6191 = torch.aten.clone %314, %none_7808 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_7809 = torch.constant.int 0 - %6192 = torch.aten.unsqueeze %6191, %int0_7809 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_7810 = torch.constant.int 4 - %int1_7811 = torch.constant.int 1 - %6193 = torch.prim.ListConstruct %int4_7810, %int1_7811 : (!torch.int, !torch.int) -> !torch.list - %int1_7812 = torch.constant.int 1 - %int1_7813 = torch.constant.int 1 - %6194 = torch.prim.ListConstruct %int1_7812, %int1_7813 : (!torch.int, !torch.int) -> !torch.list - %int4_7814 = torch.constant.int 4 - %int0_7815 = torch.constant.int 0 - %cpu_7816 = torch.constant.device "cpu" + %7220 = torch.aten.add.Tensor %7219, %7208, %int1_7807 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int32_7808 = torch.constant.int 32 + %7221 = torch.aten.mul.Scalar %7220, %int32_7808 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int1_7809 = torch.constant.int 1 + %7222 = torch.aten.add.Tensor %7221, %7205, %int1_7809 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_7810 = torch.constant.int 5 + %7223 = torch.prims.convert_element_type %7197, %int5_7810 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %int32_7811 = torch.constant.int 32 + %int2_7812 = torch.constant.int 2 + %int8_7813 = torch.constant.int 8 + %int32_7814 = torch.constant.int 32 + %int128_7815 = torch.constant.int 128 + %7224 = torch.prim.ListConstruct %456, %int32_7811, %int2_7812, %int8_7813, %int32_7814, %int128_7815 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7225 = torch.aten.view %7045, %7224 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7225, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_7816 = torch.constant.int 128 + %7226 = torch.prim.ListConstruct %596, %int128_7816 : (!torch.int, !torch.int) -> !torch.list + %7227 = torch.aten.view %7225, %7226 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7227, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %7228 = torch.prim.ListConstruct %7222 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> %false_7817 = torch.constant.bool false - %6195 = torch.aten.empty_strided %6193, %6194, %int4_7814, %int0_7815, %cpu_7816, %false_7817 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int28_7818 = torch.constant.int 28 - %6196 = torch.aten.fill.Scalar %6195, %int28_7818 : !torch.vtensor<[4,1],si64>, !torch.int -> 
!torch.vtensor<[4,1],si64> - %int4_7819 = torch.constant.int 4 - %int1_7820 = torch.constant.int 1 - %6197 = torch.prim.ListConstruct %int4_7819, %int1_7820 : (!torch.int, !torch.int) -> !torch.list - %6198 = torch.aten.repeat %6192, %6197 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> + %7229 = torch.aten.index_put %7227, %7228, %7223, %false_7817 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7229, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_7818 = torch.constant.int 32 + %int2_7819 = torch.constant.int 2 + %int8_7820 = torch.constant.int 8 %int32_7821 = torch.constant.int 32 - %6199 = torch.aten.mul.Scalar %6188, %int32_7821 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7822 = torch.constant.int 1 - %6200 = torch.aten.add.Tensor %6199, %6196, %int1_7822 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_7823 = torch.constant.int 2 - %6201 = torch.aten.mul.Scalar %6200, %int2_7823 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7824 = torch.constant.int 1 - %6202 = torch.aten.add.Tensor %6201, %6198, %int1_7824 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_7825 = torch.constant.int 32 - %6203 = torch.aten.mul.Scalar %6202, %int32_7825 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_7826 = torch.constant.int 1 - %6204 = torch.aten.add.Tensor %6203, %6190, %int1_7826 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %6205 = torch.prim.ListConstruct %6204 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_7827 = torch.constant.bool false - %6206 = torch.aten.index_put %6185, %6205, %6137, %false_7827 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6206, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_7828 = torch.constant.int 32 - %int2_7829 = torch.constant.int 2 - %int32_7830 = torch.constant.int 32 - %int8_7831 = torch.constant.int 8 - %int128_7832 = torch.constant.int 128 - %6207 = torch.prim.ListConstruct %437, %int32_7828, %int2_7829, %int32_7830, %int8_7831, %int128_7832 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6208 = torch.aten.view %6206, %6207 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6208, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_7833 = torch.constant.int 2097152 - %6209 = torch.prim.ListConstruct %437, %int2097152_7833 : (!torch.int, !torch.int) -> !torch.list - %6210 = torch.aten.view %6208, %6209 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6210, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_7834 = torch.constant.int 4 - %6211 = torch.prim.ListConstruct %int4_7834, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_7835 = torch.constant.int 1 - %6212 = torch.prim.ListConstruct %358, %int1_7835 : (!torch.int, !torch.int) -> !torch.list - %int4_7836 = torch.constant.int 4 - %int0_7837 = 
torch.constant.int 0 - %cpu_7838 = torch.constant.device "cpu" - %false_7839 = torch.constant.bool false - %6213 = torch.aten.empty_strided %6211, %6212, %int4_7836, %int0_7837, %cpu_7838, %false_7839 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6213, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int28_7840 = torch.constant.int 28 - %6214 = torch.aten.fill.Scalar %6213, %int28_7840 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6214, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %int128_7822 = torch.constant.int 128 + %7230 = torch.prim.ListConstruct %456, %int32_7818, %int2_7819, %int8_7820, %int32_7821, %int128_7822 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7231 = torch.aten.view %7229, %7230 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7231, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7823 = torch.constant.int 2097152 + %7232 = torch.prim.ListConstruct %456, %int2097152_7823 : (!torch.int, !torch.int) -> !torch.list + %7233 = torch.aten.view %7231, %7232 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.bind_symbolic_shape %7233, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %int32_7824 = torch.constant.int 32 + %int2_7825 = torch.constant.int 2 + %int8_7826 = torch.constant.int 8 + %int32_7827 = torch.constant.int 32 + %int128_7828 = torch.constant.int 128 + %7234 = torch.prim.ListConstruct %456, %int32_7824, %int2_7825, %int8_7826, %int32_7827, %int128_7828 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7235 = torch.aten.view %7233, %7234 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7235, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int128_7829 = torch.constant.int 128 + %7236 = torch.prim.ListConstruct %596, %int128_7829 : (!torch.int, !torch.int) -> !torch.list + %7237 = torch.aten.view %7235, %7236 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7237, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %none_7830 = torch.constant.none + %7238 = torch.aten.clone %441, %none_7830 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %7239 = torch.aten.detach %7238 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7240 = torch.aten.detach %7239 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7241 = torch.aten.detach %7240 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int1_7831 = torch.constant.int 1 + %int1_7832 = torch.constant.int 1 + %int1_7833 = torch.constant.int 1 + %7242 = torch.prim.ListConstruct %int1_7831, %int1_7832, %int1_7833 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7243 = torch.aten.view %7241, %7242 : !torch.vtensor<[],si64>, !torch.list -> !torch.vtensor<[1,1,1],si64> + %int32_7834 = torch.constant.int 32 + %7244 = torch.aten.mul.Scalar %7202, %int32_7834 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int31_7835 = torch.constant.int 31 + %int1_7836 = torch.constant.int 1 + %7245 = 
torch.aten.add.Scalar %7244, %int31_7835, %int1_7836 : !torch.vtensor<[4,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int2_7837 = torch.constant.int 2 + %7246 = torch.aten.mul.Scalar %7245, %int2_7837 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_7838 = torch.constant.int 1 + %7247 = torch.aten.add.Tensor %7246, %7243, %int1_7838 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int8_7839 = torch.constant.int 8 + %7248 = torch.aten.mul.Scalar %7247, %int8_7839 : !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,1],si64> + %int1_7840 = torch.constant.int 1 + %7249 = torch.aten.add.Tensor %7248, %7208, %int1_7840 : !torch.vtensor<[4,1,1],si64>, !torch.vtensor<[1,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int32_7841 = torch.constant.int 32 - %6215 = torch.aten.mul.Scalar %arg3, %int32_7841 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6215, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> + %7250 = torch.aten.mul.Scalar %7249, %int32_7841 : !torch.vtensor<[4,1,8],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> %int1_7842 = torch.constant.int 1 - %6216 = torch.aten.add.Tensor %6215, %6214, %int1_7842 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6216, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_7843 = torch.constant.int 4 - %6217 = torch.aten.mul.int %int4_7843, %358 : !torch.int, !torch.int -> !torch.int - %6218 = torch.prim.ListConstruct %6217 : (!torch.int) -> !torch.list - %6219 = torch.aten.view %6216, %6218 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %6219, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_7844 = torch.constant.int 32 - %int2_7845 = torch.constant.int 2 - %int32_7846 = torch.constant.int 32 + %7251 = torch.aten.add.Tensor %7250, %7205, %int1_7842 : !torch.vtensor<[4,1,8],si64>, !torch.vtensor<[4,1,1],si64>, !torch.int -> !torch.vtensor<[4,1,8],si64> + %int5_7843 = torch.constant.int 5 + %7252 = torch.prims.convert_element_type %7177, %int5_7843 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> + %7253 = torch.prim.ListConstruct %7251 : (!torch.vtensor<[4,1,8],si64>) -> !torch.list> + %false_7844 = torch.constant.bool false + %7254 = torch.aten.index_put %7237, %7253, %7252, %false_7844 : !torch.vtensor<[?,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,128],f16> + torch.bind_symbolic_shape %7254, [%454], affine_map<()[s0] -> (s0 * 16384, 128)> : !torch.vtensor<[?,128],f16> + %int32_7845 = torch.constant.int 32 + %int2_7846 = torch.constant.int 2 %int8_7847 = torch.constant.int 8 - %int128_7848 = torch.constant.int 128 - %6220 = torch.prim.ListConstruct %437, %int32_7844, %int2_7845, %int32_7846, %int8_7847, %int128_7848 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6221 = torch.aten.view %6210, %6220 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6221, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_7849 = torch.constant.int 32 - %6222 = torch.aten.mul.int %437, %int32_7849 : !torch.int, !torch.int -> !torch.int - %int2_7850 
= torch.constant.int 2 - %int32_7851 = torch.constant.int 32 - %int8_7852 = torch.constant.int 8 - %int128_7853 = torch.constant.int 128 - %6223 = torch.prim.ListConstruct %6222, %int2_7850, %int32_7851, %int8_7852, %int128_7853 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6224 = torch.aten.view %6221, %6223 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %6224, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_7854 = torch.constant.int 0 - %6225 = torch.aten.index_select %6224, %int0_7854, %6219 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %6225, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_7855 = torch.constant.int 4 - %int2_7856 = torch.constant.int 2 + %int32_7848 = torch.constant.int 32 + %int128_7849 = torch.constant.int 128 + %7255 = torch.prim.ListConstruct %456, %int32_7845, %int2_7846, %int8_7847, %int32_7848, %int128_7849 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7256 = torch.aten.view %7254, %7255 : !torch.vtensor<[?,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7256, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %int2097152_7850 = torch.constant.int 2097152 + %7257 = torch.prim.ListConstruct %456, %int2097152_7850 : (!torch.int, !torch.int) -> !torch.list + %7258 = torch.aten.view %7256, %7257 : !torch.vtensor<[?,32,2,8,32,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> + torch.overwrite.tensor.contents %7258 overwrites %arg4 : !torch.vtensor<[?,2097152],f16>, !torch.tensor<[?,2097152],f16> + torch.bind_symbolic_shape %7258, [%454], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> + %none_7851 = torch.constant.none + %7259 = torch.aten.clone %442, %none_7851 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %7260 = torch.aten.detach %7259 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7261 = torch.aten.detach %7260 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7262 = torch.aten.detach %7261 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_7852 = torch.constant.none + %7263 = torch.aten.clone %443, %none_7852 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %7264 = torch.aten.detach %7263 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7265 = torch.aten.detach %7264 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7266 = torch.aten.detach %7265 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %none_7853 = torch.constant.none + %7267 = torch.aten.clone %444, %none_7853 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> + %7268 = torch.aten.detach %7267 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7269 = torch.aten.detach %7268 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %7270 = torch.aten.detach %7269 : !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %int32_7854 = torch.constant.int 32 + %int2_7855 = torch.constant.int 2 + %int8_7856 = torch.constant.int 8 %int32_7857 = torch.constant.int 32 - %int8_7858 = torch.constant.int 8 - %int128_7859 = torch.constant.int 128 - %6226 = torch.prim.ListConstruct %int4_7855, %358, %int2_7856, %int32_7857, %int8_7858, %int128_7859 : 
(!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6227 = torch.aten.view %6225, %6226 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6227, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_7860 = torch.constant.int 0 - %int0_7861 = torch.constant.int 0 - %int9223372036854775807_7862 = torch.constant.int 9223372036854775807 - %int1_7863 = torch.constant.int 1 - %6228 = torch.aten.slice.Tensor %6227, %int0_7860, %int0_7861, %int9223372036854775807_7862, %int1_7863 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6228, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_7864 = torch.constant.int 1 + %int128_7858 = torch.constant.int 128 + %7271 = torch.prim.ListConstruct %456, %int32_7854, %int2_7855, %int8_7856, %int32_7857, %int128_7858 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7272 = torch.aten.view %7258, %7271 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,8,32,128],f16> + torch.bind_symbolic_shape %7272, [%454], affine_map<()[s0] -> (s0, 32, 2, 8, 32, 128)> : !torch.vtensor<[?,32,2,8,32,128],f16> + %7273 = torch_c.to_builtin_tensor %7272 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %7274 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_7859 = tensor.cast %7274 : tensor<4x?xi64> to tensor + %7275 = torch_c.to_builtin_tensor %7262 : !torch.vtensor<[],si64> -> tensor + %7276 = torch_c.to_builtin_tensor %7266 : !torch.vtensor<[],si64> -> tensor + %7277 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%7273, %cast_7859, %7275, %7276) : (tensor, tensor, tensor, tensor) -> tensor + %cast_7860 = tensor.cast %7277 : tensor to tensor<4x?x8x32x128xf16> + %7278 = torch_c.from_builtin_tensor %cast_7860 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %7278, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %7279 = torch_c.to_builtin_tensor %7272 : !torch.vtensor<[?,32,2,8,32,128],f16> -> tensor + %7280 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[4,?],si64> -> tensor<4x?xi64> + %cast_7861 = tensor.cast %7280 : tensor<4x?xi64> to tensor + %7281 = torch_c.to_builtin_tensor %7262 : !torch.vtensor<[],si64> -> tensor + %7282 = torch_c.to_builtin_tensor %7270 : !torch.vtensor<[],si64> -> tensor + %7283 = util.call @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%7279, %cast_7861, %7281, %7282) : (tensor, tensor, tensor, tensor) -> tensor + %cast_7862 = tensor.cast %7283 : tensor to tensor<4x?x8x32x128xf16> + %7284 = torch_c.from_builtin_tensor %cast_7862 : tensor<4x?x8x32x128xf16> -> !torch.vtensor<[4,?,8,32,128],f16> + torch.bind_symbolic_shape %7284, [%453], affine_map<()[s0] -> (4, s0, 8, 32, 128)> : !torch.vtensor<[4,?,8,32,128],f16> + %int2_7863 = torch.constant.int 2 + %int3_7864 = torch.constant.int 3 + %7285 = torch.aten.transpose.int 
%7278, %int2_7863, %int3_7864 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7285, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> %int0_7865 = torch.constant.int 0 - %int9223372036854775807_7866 = torch.constant.int 9223372036854775807 - %int1_7867 = torch.constant.int 1 - %6229 = torch.aten.slice.Tensor %6228, %int1_7864, %int0_7865, %int9223372036854775807_7866, %int1_7867 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6229, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_7868 = torch.constant.int 2 - %int0_7869 = torch.constant.int 0 - %6230 = torch.aten.select.int %6229, %int2_7868, %int0_7869 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6230, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_7870 = torch.constant.int 32 - %6231 = torch.aten.mul.int %358, %int32_7870 : !torch.int, !torch.int -> !torch.int - %int2_7871 = torch.constant.int 2 - %int0_7872 = torch.constant.int 0 - %int1_7873 = torch.constant.int 1 - %6232 = torch.aten.slice.Tensor %6230, %int2_7871, %int0_7872, %6231, %int1_7873 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6232, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_7874 = torch.constant.int 0 - %6233 = torch.aten.clone %6232, %int0_7874 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6233, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_7875 = torch.constant.int 1 - %6234 = torch.aten.size.int %6229, %int1_7875 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_7876 = torch.constant.int 32 - %6235 = torch.aten.mul.int %6234, %int32_7876 : !torch.int, !torch.int -> !torch.int - %int4_7877 = torch.constant.int 4 - %int8_7878 = torch.constant.int 8 + %7286 = torch.aten.clone %7285, %int0_7865 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7286, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_7866 = torch.constant.int 4 + %int8_7867 = torch.constant.int 8 + %int128_7868 = torch.constant.int 128 + %7287 = torch.prim.ListConstruct %int4_7866, %457, %int8_7867, %int128_7868 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7288 = torch.aten._unsafe_view %7286, %7287 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7288, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int2_7869 = torch.constant.int 2 + %int3_7870 = torch.constant.int 3 + %7289 = torch.aten.transpose.int %7284, %int2_7869, %int3_7870 : !torch.vtensor<[4,?,8,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7289, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int0_7871 = torch.constant.int 0 + %7290 = torch.aten.clone %7289, %int0_7871 : !torch.vtensor<[4,?,32,8,128],f16>, 
!torch.int -> !torch.vtensor<[4,?,32,8,128],f16> + torch.bind_symbolic_shape %7290, [%453], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> + %int4_7872 = torch.constant.int 4 + %int8_7873 = torch.constant.int 8 + %int128_7874 = torch.constant.int 128 + %7291 = torch.prim.ListConstruct %int4_7872, %457, %int8_7873, %int128_7874 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7292 = torch.aten._unsafe_view %7290, %7291 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> + torch.bind_symbolic_shape %7292, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> + %int-2_7875 = torch.constant.int -2 + %7293 = torch.aten.unsqueeze %7288, %int-2_7875 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %7293, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_7876 = torch.constant.int 4 + %int8_7877 = torch.constant.int 8 + %int4_7878 = torch.constant.int 4 %int128_7879 = torch.constant.int 128 - %6236 = torch.prim.ListConstruct %int4_7877, %6235, %int8_7878, %int128_7879 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6237 = torch.aten._unsafe_view %6233, %6236 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6237, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_7880 = torch.constant.int 0 + %7294 = torch.prim.ListConstruct %int4_7876, %457, %int8_7877, %int4_7878, %int128_7879 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_7880 = torch.constant.bool false + %7295 = torch.aten.expand %7293, %7294, %false_7880 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7295, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> %int0_7881 = torch.constant.int 0 - %int9223372036854775807_7882 = torch.constant.int 9223372036854775807 - %int1_7883 = torch.constant.int 1 - %6238 = torch.aten.slice.Tensor %6237, %int0_7880, %int0_7881, %int9223372036854775807_7882, %int1_7883 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6238, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_7884 = torch.constant.int 0 - %int0_7885 = torch.constant.int 0 - %int9223372036854775807_7886 = torch.constant.int 9223372036854775807 - %int1_7887 = torch.constant.int 1 - %6239 = torch.aten.slice.Tensor %6227, %int0_7884, %int0_7885, %int9223372036854775807_7886, %int1_7887 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6239, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_7888 = torch.constant.int 1 - %int0_7889 = torch.constant.int 0 - %int9223372036854775807_7890 = torch.constant.int 9223372036854775807 - %int1_7891 = torch.constant.int 1 - %6240 = torch.aten.slice.Tensor %6239, %int1_7888, %int0_7889, %int9223372036854775807_7890, %int1_7891 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6240, [%356], affine_map<()[s0] -> (4, s0, 2, 
32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_7892 = torch.constant.int 2 - %int1_7893 = torch.constant.int 1 - %6241 = torch.aten.select.int %6240, %int2_7892, %int1_7893 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6241, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_7894 = torch.constant.int 2 - %int0_7895 = torch.constant.int 0 - %int1_7896 = torch.constant.int 1 - %6242 = torch.aten.slice.Tensor %6241, %int2_7894, %int0_7895, %6231, %int1_7896 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6242, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_7897 = torch.constant.int 0 - %6243 = torch.aten.clone %6242, %int0_7897 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6243, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_7898 = torch.constant.int 1 - %6244 = torch.aten.size.int %6240, %int1_7898 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_7899 = torch.constant.int 32 - %6245 = torch.aten.mul.int %6244, %int32_7899 : !torch.int, !torch.int -> !torch.int - %int4_7900 = torch.constant.int 4 - %int8_7901 = torch.constant.int 8 - %int128_7902 = torch.constant.int 128 - %6246 = torch.prim.ListConstruct %int4_7900, %6245, %int8_7901, %int128_7902 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6247 = torch.aten._unsafe_view %6243, %6246 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6247, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_7903 = torch.constant.int 0 - %int0_7904 = torch.constant.int 0 - %int9223372036854775807_7905 = torch.constant.int 9223372036854775807 - %int1_7906 = torch.constant.int 1 - %6248 = torch.aten.slice.Tensor %6247, %int0_7903, %int0_7904, %int9223372036854775807_7905, %int1_7906 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6248, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_7907 = torch.constant.int -2 - %6249 = torch.aten.unsqueeze %6238, %int-2_7907 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6249, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_7908 = torch.constant.int 1 - %6250 = torch.aten.size.int %6237, %int1_7908 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_7909 = torch.constant.int 4 - %int8_7910 = torch.constant.int 8 - %int4_7911 = torch.constant.int 4 - %int128_7912 = torch.constant.int 128 - %6251 = torch.prim.ListConstruct %int4_7909, %6250, %int8_7910, %int4_7911, %int128_7912 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7913 = torch.constant.bool false - %6252 = torch.aten.expand %6249, %6251, %false_7913 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6252, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_7914 = torch.constant.int 0 
- %6253 = torch.aten.clone %6252, %int0_7914 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6253, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7915 = torch.constant.int 4 - %int32_7916 = torch.constant.int 32 - %int128_7917 = torch.constant.int 128 - %6254 = torch.prim.ListConstruct %int4_7915, %6250, %int32_7916, %int128_7917 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6255 = torch.aten._unsafe_view %6253, %6254 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6255, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_7918 = torch.constant.int -2 - %6256 = torch.aten.unsqueeze %6248, %int-2_7918 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6256, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_7919 = torch.constant.int 1 - %6257 = torch.aten.size.int %6247, %int1_7919 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_7920 = torch.constant.int 4 - %int8_7921 = torch.constant.int 8 - %int4_7922 = torch.constant.int 4 - %int128_7923 = torch.constant.int 128 - %6258 = torch.prim.ListConstruct %int4_7920, %6257, %int8_7921, %int4_7922, %int128_7923 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_7924 = torch.constant.bool false - %6259 = torch.aten.expand %6256, %6258, %false_7924 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6259, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_7925 = torch.constant.int 0 - %6260 = torch.aten.clone %6259, %int0_7925 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6260, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_7926 = torch.constant.int 4 - %int32_7927 = torch.constant.int 32 - %int128_7928 = torch.constant.int 128 - %6261 = torch.prim.ListConstruct %int4_7926, %6257, %int32_7927, %int128_7928 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6262 = torch.aten._unsafe_view %6260, %6261 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6262, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_7929 = torch.constant.int 1 - %int2_7930 = torch.constant.int 2 - %6263 = torch.aten.transpose.int %6143, %int1_7929, %int2_7930 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_7931 = torch.constant.int 1 - %int2_7932 = torch.constant.int 2 - %6264 = torch.aten.transpose.int %6255, %int1_7931, %int2_7932 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6264, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %7296 = torch.aten.clone %7295, %int0_7881 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7296, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_7882 = torch.constant.int 4 + %int32_7883 = 
torch.constant.int 32 + %int128_7884 = torch.constant.int 128 + %7297 = torch.prim.ListConstruct %int4_7882, %457, %int32_7883, %int128_7884 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7298 = torch.aten._unsafe_view %7296, %7297 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7298, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int-2_7885 = torch.constant.int -2 + %7299 = torch.aten.unsqueeze %7292, %int-2_7885 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> + torch.bind_symbolic_shape %7299, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> + %int4_7886 = torch.constant.int 4 + %int8_7887 = torch.constant.int 8 + %int4_7888 = torch.constant.int 4 + %int128_7889 = torch.constant.int 128 + %7300 = torch.prim.ListConstruct %int4_7886, %457, %int8_7887, %int4_7888, %int128_7889 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %false_7890 = torch.constant.bool false + %7301 = torch.aten.expand %7299, %7300, %false_7890 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7301, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int0_7891 = torch.constant.int 0 + %7302 = torch.aten.clone %7301, %int0_7891 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> + torch.bind_symbolic_shape %7302, [%453], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> + %int4_7892 = torch.constant.int 4 + %int32_7893 = torch.constant.int 32 + %int128_7894 = torch.constant.int 128 + %7303 = torch.prim.ListConstruct %int4_7892, %457, %int32_7893, %int128_7894 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %7304 = torch.aten._unsafe_view %7302, %7303 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> + torch.bind_symbolic_shape %7304, [%453], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> + %int1_7895 = torch.constant.int 1 + %int2_7896 = torch.constant.int 2 + %7305 = torch.aten.transpose.int %7187, %int1_7895, %int2_7896 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> + %int1_7897 = torch.constant.int 1 + %int2_7898 = torch.constant.int 2 + %7306 = torch.aten.transpose.int %7298, %int1_7897, %int2_7898 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7306, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %int1_7899 = torch.constant.int 1 + %int2_7900 = torch.constant.int 2 + %7307 = torch.aten.transpose.int %7304, %int1_7899, %int2_7900 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> + torch.bind_symbolic_shape %7307, [%453], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> + %float0.000000e00_7901 = torch.constant.float 0.000000e+00 + %false_7902 = torch.constant.bool false + %none_7903 = torch.constant.none + %7308:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%7305, %7306, %7307, %float0.000000e00_7901, %false_7902, %470, %none_7903) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, 
!torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) + %int1_7904 = torch.constant.int 1 + %int2_7905 = torch.constant.int 2 + %7309 = torch.aten.transpose.int %7308#0, %int1_7904, %int2_7905 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int4_7906 = torch.constant.int 4 + %int1_7907 = torch.constant.int 1 + %int4096_7908 = torch.constant.int 4096 + %7310 = torch.prim.ListConstruct %int4_7906, %int1_7907, %int4096_7908 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7311 = torch.aten.view %7309, %7310 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int-2_7909 = torch.constant.int -2 + %int-1_7910 = torch.constant.int -1 + %7312 = torch.aten.transpose.int %445, %int-2_7909, %int-1_7910 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int5_7911 = torch.constant.int 5 + %7313 = torch.prims.convert_element_type %7312, %int5_7911 : !torch.vtensor<[4096,4096],f16>, !torch.int -> !torch.vtensor<[4096,4096],f16> + %int4_7912 = torch.constant.int 4 + %int4096_7913 = torch.constant.int 4096 + %7314 = torch.prim.ListConstruct %int4_7912, %int4096_7913 : (!torch.int, !torch.int) -> !torch.list + %7315 = torch.aten.view %7311, %7314 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %7316 = torch.aten.mm %7315, %7313 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_7914 = torch.constant.int 4 + %int1_7915 = torch.constant.int 1 + %int4096_7916 = torch.constant.int 4096 + %7317 = torch.prim.ListConstruct %int4_7914, %int1_7915, %int4096_7916 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7318 = torch.aten.view %7316, %7317 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_7917 = torch.constant.int 1 + %7319 = torch.aten.add.Tensor %7140, %7318, %int1_7917 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_7918 = torch.constant.int 6 + %7320 = torch.prims.convert_element_type %7319, %int6_7918 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_7919 = torch.constant.int 2 + %7321 = torch.aten.pow.Tensor_Scalar %7320, %int2_7919 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_7920 = torch.constant.int -1 + %7322 = torch.prim.ListConstruct %int-1_7920 : (!torch.int) -> !torch.list + %true_7921 = torch.constant.bool true + %none_7922 = torch.constant.none + %7323 = torch.aten.mean.dim %7321, %7322, %true_7921, %none_7922 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_7923 = torch.constant.float 9.9999997473787516E-6 + %int1_7924 = torch.constant.int 1 + %7324 = torch.aten.add.Scalar %7323, %float9.999990e-06_7923, %int1_7924 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %7325 = torch.aten.rsqrt %7324 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %7326 = torch.aten.mul.Tensor %7320, %7325 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> + %int5_7925 = torch.constant.int 5 + %7327 = torch.prims.convert_element_type %7326, %int5_7925 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %7328 = torch.aten.mul.Tensor %446, 
%7327 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_7926 = torch.constant.int 5 + %7329 = torch.prims.convert_element_type %7328, %int5_7926 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_7927 = torch.constant.int -2 + %int-1_7928 = torch.constant.int -1 + %7330 = torch.aten.transpose.int %447, %int-2_7927, %int-1_7928 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_7929 = torch.constant.int 5 + %7331 = torch.prims.convert_element_type %7330, %int5_7929 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_7930 = torch.constant.int 4 + %int4096_7931 = torch.constant.int 4096 + %7332 = torch.prim.ListConstruct %int4_7930, %int4096_7931 : (!torch.int, !torch.int) -> !torch.list + %7333 = torch.aten.view %7329, %7332 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %7334 = torch.aten.mm %7333, %7331 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %int4_7932 = torch.constant.int 4 %int1_7933 = torch.constant.int 1 - %int2_7934 = torch.constant.int 2 - %6265 = torch.aten.transpose.int %6262, %int1_7933, %int2_7934 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6265, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_7935 = torch.constant.float 0.000000e+00 - %false_7936 = torch.constant.bool false - %none_7937 = torch.constant.none - %6266:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6263, %6264, %6265, %float0.000000e00_7935, %false_7936, %368, %none_7937) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_7938 = torch.constant.int 1 - %int2_7939 = torch.constant.int 2 - %6267 = torch.aten.transpose.int %6266#0, %int1_7938, %int2_7939 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> + %int14336_7934 = torch.constant.int 14336 + %7335 = torch.prim.ListConstruct %int4_7932, %int1_7933, %int14336_7934 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7336 = torch.aten.view %7334, %7335 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %7337 = torch.aten.silu %7336 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> + %int-2_7935 = torch.constant.int -2 + %int-1_7936 = torch.constant.int -1 + %7338 = torch.aten.transpose.int %448, %int-2_7935, %int-1_7936 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int5_7937 = torch.constant.int 5 + %7339 = torch.prims.convert_element_type %7338, %int5_7937 : !torch.vtensor<[4096,14336],f16>, !torch.int -> !torch.vtensor<[4096,14336],f16> + %int4_7938 = torch.constant.int 4 + %int4096_7939 = torch.constant.int 4096 + %7340 = torch.prim.ListConstruct %int4_7938, %int4096_7939 : (!torch.int, !torch.int) -> !torch.list + %7341 = torch.aten.view %7329, %7340 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> + %7342 = torch.aten.mm %7341, %7339 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> %int4_7940 = 
torch.constant.int 4 %int1_7941 = torch.constant.int 1 - %int4096_7942 = torch.constant.int 4096 - %6268 = torch.prim.ListConstruct %int4_7940, %int1_7941, %int4096_7942 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6269 = torch.aten.view %6267, %6268 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int14336_7942 = torch.constant.int 14336 + %7343 = torch.prim.ListConstruct %int4_7940, %int1_7941, %int14336_7942 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7344 = torch.aten.view %7342, %7343 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> + %7345 = torch.aten.mul.Tensor %7337, %7344 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> %int-2_7943 = torch.constant.int -2 %int-1_7944 = torch.constant.int -1 - %6270 = torch.aten.transpose.int %315, %int-2_7943, %int-1_7944 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7945 = torch.constant.int 4 - %int4096_7946 = torch.constant.int 4096 - %6271 = torch.prim.ListConstruct %int4_7945, %int4096_7946 : (!torch.int, !torch.int) -> !torch.list - %6272 = torch.aten.view %6269, %6271 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6273 = torch.aten.mm %6272, %6270 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_7947 = torch.constant.int 4 - %int1_7948 = torch.constant.int 1 - %int4096_7949 = torch.constant.int 4096 - %6274 = torch.prim.ListConstruct %int4_7947, %int1_7948, %int4096_7949 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6275 = torch.aten.view %6273, %6274 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_7950 = torch.constant.int 1 - %6276 = torch.aten.add.Tensor %6103, %6275, %int1_7950 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_7951 = torch.constant.int 6 - %6277 = torch.prims.convert_element_type %6276, %int6_7951 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_7952 = torch.constant.int 2 - %6278 = torch.aten.pow.Tensor_Scalar %6277, %int2_7952 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_7953 = torch.constant.int -1 - %6279 = torch.prim.ListConstruct %int-1_7953 : (!torch.int) -> !torch.list - %true_7954 = torch.constant.bool true - %none_7955 = torch.constant.none - %6280 = torch.aten.mean.dim %6278, %6279, %true_7954, %none_7955 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_7956 = torch.constant.float 9.9999997473787516E-6 - %int1_7957 = torch.constant.int 1 - %6281 = torch.aten.add.Scalar %6280, %float9.999990e-06_7956, %int1_7957 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %6282 = torch.aten.rsqrt %6281 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %6283 = torch.aten.mul.Tensor %6277, %6282 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_7958 = torch.constant.int 5 - %6284 = torch.prims.convert_element_type %6283, %int5_7958 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %6285 = torch.aten.mul.Tensor %316, %6284 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %7346 = 
torch.aten.transpose.int %449, %int-2_7943, %int-1_7944 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int5_7945 = torch.constant.int 5 + %7347 = torch.prims.convert_element_type %7346, %int5_7945 : !torch.vtensor<[14336,4096],f16>, !torch.int -> !torch.vtensor<[14336,4096],f16> + %int4_7946 = torch.constant.int 4 + %int14336_7947 = torch.constant.int 14336 + %7348 = torch.prim.ListConstruct %int4_7946, %int14336_7947 : (!torch.int, !torch.int) -> !torch.list + %7349 = torch.aten.view %7345, %7348 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> + %7350 = torch.aten.mm %7349, %7347 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> + %int4_7948 = torch.constant.int 4 + %int1_7949 = torch.constant.int 1 + %int4096_7950 = torch.constant.int 4096 + %7351 = torch.prim.ListConstruct %int4_7948, %int1_7949, %int4096_7950 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %7352 = torch.aten.view %7350, %7351 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> + %int1_7951 = torch.constant.int 1 + %7353 = torch.aten.add.Tensor %7319, %7352, %int1_7951 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int6_7952 = torch.constant.int 6 + %7354 = torch.prims.convert_element_type %7353, %int6_7952 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int2_7953 = torch.constant.int 2 + %7355 = torch.aten.pow.Tensor_Scalar %7354, %int2_7953 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> + %int-1_7954 = torch.constant.int -1 + %7356 = torch.prim.ListConstruct %int-1_7954 : (!torch.int) -> !torch.list + %true_7955 = torch.constant.bool true + %none_7956 = torch.constant.none + %7357 = torch.aten.mean.dim %7355, %7356, %true_7955, %none_7956 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> + %float9.999990e-06_7957 = torch.constant.float 9.9999997473787516E-6 + %int1_7958 = torch.constant.int 1 + %7358 = torch.aten.add.Scalar %7357, %float9.999990e-06_7957, %int1_7958 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> + %7359 = torch.aten.rsqrt %7358 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> + %7360 = torch.aten.mul.Tensor %7354, %7359 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> %int5_7959 = torch.constant.int 5 - %6286 = torch.prims.convert_element_type %6285, %int5_7959 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_7960 = torch.constant.int -2 - %int-1_7961 = torch.constant.int -1 - %6287 = torch.aten.transpose.int %317, %int-2_7960, %int-1_7961 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_7962 = torch.constant.int 4 - %int4096_7963 = torch.constant.int 4096 - %6288 = torch.prim.ListConstruct %int4_7962, %int4096_7963 : (!torch.int, !torch.int) -> !torch.list - %6289 = torch.aten.view %6286, %6288 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6290 = torch.aten.mm %6289, %6287 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> + %7361 = torch.prims.convert_element_type %7360, %int5_7959 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %7362 = 
torch.aten.mul.Tensor %450, %7361 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> + %int5_7960 = torch.constant.int 5 + %7363 = torch.prims.convert_element_type %7362, %int5_7960 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> + %int-2_7961 = torch.constant.int -2 + %int-1_7962 = torch.constant.int -1 + %7364 = torch.aten.transpose.int %451, %int-2_7961, %int-1_7962 : !torch.vtensor<[128256,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,128256],f16> + %int5_7963 = torch.constant.int 5 + %7365 = torch.prims.convert_element_type %7364, %int5_7963 : !torch.vtensor<[4096,128256],f16>, !torch.int -> !torch.vtensor<[4096,128256],f16> %int4_7964 = torch.constant.int 4 - %int1_7965 = torch.constant.int 1 - %int14336_7966 = torch.constant.int 14336 - %6291 = torch.prim.ListConstruct %int4_7964, %int1_7965, %int14336_7966 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6292 = torch.aten.view %6290, %6291 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %6293 = torch.aten.silu %6292 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_7967 = torch.constant.int -2 - %int-1_7968 = torch.constant.int -1 - %6294 = torch.aten.transpose.int %318, %int-2_7967, %int-1_7968 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_7969 = torch.constant.int 4 - %int4096_7970 = torch.constant.int 4096 - %6295 = torch.prim.ListConstruct %int4_7969, %int4096_7970 : (!torch.int, !torch.int) -> !torch.list - %6296 = torch.aten.view %6286, %6295 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6297 = torch.aten.mm %6296, %6294 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_7971 = torch.constant.int 4 - %int1_7972 = torch.constant.int 1 - %int14336_7973 = torch.constant.int 14336 - %6298 = torch.prim.ListConstruct %int4_7971, %int1_7972, %int14336_7973 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6299 = torch.aten.view %6297, %6298 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %6300 = torch.aten.mul.Tensor %6293, %6299 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_7974 = torch.constant.int -2 - %int-1_7975 = torch.constant.int -1 - %6301 = torch.aten.transpose.int %319, %int-2_7974, %int-1_7975 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_7976 = torch.constant.int 4 - %int14336_7977 = torch.constant.int 14336 - %6302 = torch.prim.ListConstruct %int4_7976, %int14336_7977 : (!torch.int, !torch.int) -> !torch.list - %6303 = torch.aten.view %6300, %6302 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %6304 = torch.aten.mm %6303, %6301 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_7978 = torch.constant.int 4 - %int1_7979 = torch.constant.int 1 - %int4096_7980 = torch.constant.int 4096 - %6305 = torch.prim.ListConstruct %int4_7978, %int1_7979, %int4096_7980 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6306 = torch.aten.view %6304, %6305 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_7981 = torch.constant.int 1 - %6307 = torch.aten.add.Tensor %6276, %6306, %int1_7981 : !torch.vtensor<[4,1,4096],f16>, 
!torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_7982 = torch.constant.int 6 - %6308 = torch.prims.convert_element_type %6307, %int6_7982 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_7983 = torch.constant.int 2 - %6309 = torch.aten.pow.Tensor_Scalar %6308, %int2_7983 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_7984 = torch.constant.int -1 - %6310 = torch.prim.ListConstruct %int-1_7984 : (!torch.int) -> !torch.list - %true_7985 = torch.constant.bool true - %none_7986 = torch.constant.none - %6311 = torch.aten.mean.dim %6309, %6310, %true_7985, %none_7986 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_7987 = torch.constant.float 9.9999997473787516E-6 - %int1_7988 = torch.constant.int 1 - %6312 = torch.aten.add.Scalar %6311, %float9.999990e-06_7987, %int1_7988 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %6313 = torch.aten.rsqrt %6312 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %6314 = torch.aten.mul.Tensor %6308, %6313 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_7989 = torch.constant.int 5 - %6315 = torch.prims.convert_element_type %6314, %int5_7989 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %6316 = torch.aten.mul.Tensor %320, %6315 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_7990 = torch.constant.int 5 - %6317 = torch.prims.convert_element_type %6316, %int5_7990 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_7991 = torch.constant.int -2 - %int-1_7992 = torch.constant.int -1 - %6318 = torch.aten.transpose.int %321, %int-2_7991, %int-1_7992 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_7993 = torch.constant.int 4 - %int4096_7994 = torch.constant.int 4096 - %6319 = torch.prim.ListConstruct %int4_7993, %int4096_7994 : (!torch.int, !torch.int) -> !torch.list - %6320 = torch.aten.view %6317, %6319 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6321 = torch.aten.mm %6320, %6318 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_7995 = torch.constant.int 4 - %int1_7996 = torch.constant.int 1 - %int4096_7997 = torch.constant.int 4096 - %6322 = torch.prim.ListConstruct %int4_7995, %int1_7996, %int4096_7997 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6323 = torch.aten.view %6321, %6322 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_7998 = torch.constant.int -2 - %int-1_7999 = torch.constant.int -1 - %6324 = torch.aten.transpose.int %322, %int-2_7998, %int-1_7999 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_8000 = torch.constant.int 4 - %int4096_8001 = torch.constant.int 4096 - %6325 = torch.prim.ListConstruct %int4_8000, %int4096_8001 : (!torch.int, !torch.int) -> !torch.list - %6326 = torch.aten.view %6317, %6325 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6327 = torch.aten.mm %6326, %6324 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_8002 = torch.constant.int 4 - %int1_8003 = torch.constant.int 1 - 
%int1024_8004 = torch.constant.int 1024 - %6328 = torch.prim.ListConstruct %int4_8002, %int1_8003, %int1024_8004 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6329 = torch.aten.view %6327, %6328 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_8005 = torch.constant.int -2 - %int-1_8006 = torch.constant.int -1 - %6330 = torch.aten.transpose.int %323, %int-2_8005, %int-1_8006 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_8007 = torch.constant.int 4 - %int4096_8008 = torch.constant.int 4096 - %6331 = torch.prim.ListConstruct %int4_8007, %int4096_8008 : (!torch.int, !torch.int) -> !torch.list - %6332 = torch.aten.view %6317, %6331 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6333 = torch.aten.mm %6332, %6330 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_8009 = torch.constant.int 4 - %int1_8010 = torch.constant.int 1 - %int1024_8011 = torch.constant.int 1024 - %6334 = torch.prim.ListConstruct %int4_8009, %int1_8010, %int1024_8011 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6335 = torch.aten.view %6333, %6334 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_8012 = torch.constant.int 4 - %int1_8013 = torch.constant.int 1 - %int32_8014 = torch.constant.int 32 - %int128_8015 = torch.constant.int 128 - %6336 = torch.prim.ListConstruct %int4_8012, %int1_8013, %int32_8014, %int128_8015 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6337 = torch.aten.view %6323, %6336 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_8016 = torch.constant.int 4 - %int1_8017 = torch.constant.int 1 - %int8_8018 = torch.constant.int 8 - %int128_8019 = torch.constant.int 128 - %6338 = torch.prim.ListConstruct %int4_8016, %int1_8017, %int8_8018, %int128_8019 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6339 = torch.aten.view %6329, %6338 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_8020 = torch.constant.int 4 - %int1_8021 = torch.constant.int 1 - %int8_8022 = torch.constant.int 8 - %int128_8023 = torch.constant.int 128 - %6340 = torch.prim.ListConstruct %int4_8020, %int1_8021, %int8_8022, %int128_8023 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6341 = torch.aten.view %6335, %6340 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_8024 = torch.constant.int 6 - %6342 = torch.prims.convert_element_type %6337, %int6_8024 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %6343 = torch_c.to_builtin_tensor %6342 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %6344 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %6345 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%6343, %6344) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %6346 = torch_c.from_builtin_tensor %6345 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_8025 = torch.constant.int 5 - %6347 = torch.prims.convert_element_type %6346, %int5_8025 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_8026 = torch.constant.int 6 - %6348 = torch.prims.convert_element_type %6339, %int6_8026 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> 
!torch.vtensor<[4,1,8,128],f32> - %6349 = torch_c.to_builtin_tensor %6348 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %6350 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %6351 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%6349, %6350) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %6352 = torch_c.from_builtin_tensor %6351 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_8027 = torch.constant.int 5 - %6353 = torch.prims.convert_element_type %6352, %int5_8027 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_8028 = torch.constant.int 32 - %6354 = torch.aten.floor_divide.Scalar %arg2, %int32_8028 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_8029 = torch.constant.int 1 - %6355 = torch.aten.unsqueeze %6354, %int1_8029 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8030 = torch.constant.int 1 - %false_8031 = torch.constant.bool false - %6356 = torch.aten.gather %arg3, %int1_8030, %6355, %false_8031 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_8032 = torch.constant.int 32 - %6357 = torch.aten.remainder.Scalar %arg2, %int32_8032 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_8033 = torch.constant.int 1 - %6358 = torch.aten.unsqueeze %6357, %int1_8033 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_8034 = torch.constant.none - %6359 = torch.aten.clone %324, %none_8034 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_8035 = torch.constant.int 0 - %6360 = torch.aten.unsqueeze %6359, %int0_8035 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_8036 = torch.constant.int 4 - %int1_8037 = torch.constant.int 1 - %6361 = torch.prim.ListConstruct %int4_8036, %int1_8037 : (!torch.int, !torch.int) -> !torch.list - %int1_8038 = torch.constant.int 1 - %int1_8039 = torch.constant.int 1 - %6362 = torch.prim.ListConstruct %int1_8038, %int1_8039 : (!torch.int, !torch.int) -> !torch.list - %int4_8040 = torch.constant.int 4 - %int0_8041 = torch.constant.int 0 - %cpu_8042 = torch.constant.device "cpu" - %false_8043 = torch.constant.bool false - %6363 = torch.aten.empty_strided %6361, %6362, %int4_8040, %int0_8041, %cpu_8042, %false_8043 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int29 = torch.constant.int 29 - %6364 = torch.aten.fill.Scalar %6363, %int29 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_8044 = torch.constant.int 4 - %int1_8045 = torch.constant.int 1 - %6365 = torch.prim.ListConstruct %int4_8044, %int1_8045 : (!torch.int, !torch.int) -> !torch.list - %6366 = torch.aten.repeat %6360, %6365 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_8046 = torch.constant.int 32 - %6367 = torch.aten.mul.Scalar %6356, %int32_8046 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8047 = torch.constant.int 1 - %6368 = torch.aten.add.Tensor %6367, %6364, %int1_8047 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_8048 = torch.constant.int 2 - %6369 = torch.aten.mul.Scalar %6368, %int2_8048 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8049 = torch.constant.int 1 - %6370 = 
torch.aten.add.Tensor %6369, %6366, %int1_8049 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_8050 = torch.constant.int 32 - %6371 = torch.aten.mul.Scalar %6370, %int32_8050 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8051 = torch.constant.int 1 - %6372 = torch.aten.add.Tensor %6371, %6358, %int1_8051 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_8052 = torch.constant.int 32 - %int2_8053 = torch.constant.int 2 - %int32_8054 = torch.constant.int 32 - %int8_8055 = torch.constant.int 8 - %int128_8056 = torch.constant.int 128 - %6373 = torch.prim.ListConstruct %437, %int32_8052, %int2_8053, %int32_8054, %int8_8055, %int128_8056 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6374 = torch.aten.view %6210, %6373 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6374, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_8057 = torch.constant.int 32 - %6375 = torch.aten.mul.int %437, %int32_8057 : !torch.int, !torch.int -> !torch.int - %int2_8058 = torch.constant.int 2 - %6376 = torch.aten.mul.int %6375, %int2_8058 : !torch.int, !torch.int -> !torch.int - %int32_8059 = torch.constant.int 32 - %6377 = torch.aten.mul.int %6376, %int32_8059 : !torch.int, !torch.int -> !torch.int - %int8_8060 = torch.constant.int 8 - %int128_8061 = torch.constant.int 128 - %6378 = torch.prim.ListConstruct %6377, %int8_8060, %int128_8061 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6379 = torch.aten.view %6374, %6378 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6379, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %6380 = torch.prim.ListConstruct %6372 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_8062 = torch.constant.bool false - %6381 = torch.aten.index_put %6379, %6380, %6353, %false_8062 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6381, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_8063 = torch.constant.int 32 - %int2_8064 = torch.constant.int 2 - %int32_8065 = torch.constant.int 32 - %int8_8066 = torch.constant.int 8 - %int128_8067 = torch.constant.int 128 - %6382 = torch.prim.ListConstruct %437, %int32_8063, %int2_8064, %int32_8065, %int8_8066, %int128_8067 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6383 = torch.aten.view %6381, %6382 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6383, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_8068 = torch.constant.int 2097152 - %6384 = torch.prim.ListConstruct %437, %int2097152_8068 : (!torch.int, !torch.int) -> !torch.list - %6385 = torch.aten.view %6383, %6384 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6385, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_8069 = torch.constant.int 32 - %int2_8070 = torch.constant.int 2 - %int32_8071 = torch.constant.int 32 - %int8_8072 = torch.constant.int 8 
- %int128_8073 = torch.constant.int 128 - %6386 = torch.prim.ListConstruct %437, %int32_8069, %int2_8070, %int32_8071, %int8_8072, %int128_8073 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6387 = torch.aten.view %6385, %6386 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6387, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_8074 = torch.constant.int 8 - %int128_8075 = torch.constant.int 128 - %6388 = torch.prim.ListConstruct %6377, %int8_8074, %int128_8075 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6389 = torch.aten.view %6387, %6388 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6389, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_8076 = torch.constant.int 32 - %6390 = torch.aten.floor_divide.Scalar %arg2, %int32_8076 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_8077 = torch.constant.int 1 - %6391 = torch.aten.unsqueeze %6390, %int1_8077 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8078 = torch.constant.int 1 - %false_8079 = torch.constant.bool false - %6392 = torch.aten.gather %arg3, %int1_8078, %6391, %false_8079 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_8080 = torch.constant.int 32 - %6393 = torch.aten.remainder.Scalar %arg2, %int32_8080 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_8081 = torch.constant.int 1 - %6394 = torch.aten.unsqueeze %6393, %int1_8081 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_8082 = torch.constant.none - %6395 = torch.aten.clone %325, %none_8082 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_8083 = torch.constant.int 0 - %6396 = torch.aten.unsqueeze %6395, %int0_8083 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_8084 = torch.constant.int 4 - %int1_8085 = torch.constant.int 1 - %6397 = torch.prim.ListConstruct %int4_8084, %int1_8085 : (!torch.int, !torch.int) -> !torch.list - %int1_8086 = torch.constant.int 1 - %int1_8087 = torch.constant.int 1 - %6398 = torch.prim.ListConstruct %int1_8086, %int1_8087 : (!torch.int, !torch.int) -> !torch.list - %int4_8088 = torch.constant.int 4 - %int0_8089 = torch.constant.int 0 - %cpu_8090 = torch.constant.device "cpu" - %false_8091 = torch.constant.bool false - %6399 = torch.aten.empty_strided %6397, %6398, %int4_8088, %int0_8089, %cpu_8090, %false_8091 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int29_8092 = torch.constant.int 29 - %6400 = torch.aten.fill.Scalar %6399, %int29_8092 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_8093 = torch.constant.int 4 - %int1_8094 = torch.constant.int 1 - %6401 = torch.prim.ListConstruct %int4_8093, %int1_8094 : (!torch.int, !torch.int) -> !torch.list - %6402 = torch.aten.repeat %6396, %6401 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_8095 = torch.constant.int 32 - %6403 = torch.aten.mul.Scalar %6392, %int32_8095 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8096 = torch.constant.int 1 - %6404 = torch.aten.add.Tensor %6403, %6400, %int1_8096 : !torch.vtensor<[4,1],si64>, 
!torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_8097 = torch.constant.int 2 - %6405 = torch.aten.mul.Scalar %6404, %int2_8097 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8098 = torch.constant.int 1 - %6406 = torch.aten.add.Tensor %6405, %6402, %int1_8098 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_8099 = torch.constant.int 32 - %6407 = torch.aten.mul.Scalar %6406, %int32_8099 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8100 = torch.constant.int 1 - %6408 = torch.aten.add.Tensor %6407, %6394, %int1_8100 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %6409 = torch.prim.ListConstruct %6408 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_8101 = torch.constant.bool false - %6410 = torch.aten.index_put %6389, %6409, %6341, %false_8101 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6410, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_8102 = torch.constant.int 32 - %int2_8103 = torch.constant.int 2 - %int32_8104 = torch.constant.int 32 - %int8_8105 = torch.constant.int 8 - %int128_8106 = torch.constant.int 128 - %6411 = torch.prim.ListConstruct %437, %int32_8102, %int2_8103, %int32_8104, %int8_8105, %int128_8106 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6412 = torch.aten.view %6410, %6411 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6412, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_8107 = torch.constant.int 2097152 - %6413 = torch.prim.ListConstruct %437, %int2097152_8107 : (!torch.int, !torch.int) -> !torch.list - %6414 = torch.aten.view %6412, %6413 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6414, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_8108 = torch.constant.int 4 - %6415 = torch.prim.ListConstruct %int4_8108, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_8109 = torch.constant.int 1 - %6416 = torch.prim.ListConstruct %358, %int1_8109 : (!torch.int, !torch.int) -> !torch.list - %int4_8110 = torch.constant.int 4 - %int0_8111 = torch.constant.int 0 - %cpu_8112 = torch.constant.device "cpu" - %false_8113 = torch.constant.bool false - %6417 = torch.aten.empty_strided %6415, %6416, %int4_8110, %int0_8111, %cpu_8112, %false_8113 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6417, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int29_8114 = torch.constant.int 29 - %6418 = torch.aten.fill.Scalar %6417, %int29_8114 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6418, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_8115 = torch.constant.int 32 - %6419 = torch.aten.mul.Scalar %arg3, %int32_8115 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6419, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_8116 = torch.constant.int 1 - %6420 = torch.aten.add.Tensor %6419, 
%6418, %int1_8116 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64>
- torch.bind_symbolic_shape %6420, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64>
- %int4_8117 = torch.constant.int 4
- %6421 = torch.aten.mul.int %int4_8117, %358 : !torch.int, !torch.int -> !torch.int
- %6422 = torch.prim.ListConstruct %6421 : (!torch.int) -> !torch.list<int>
- %6423 = torch.aten.view %6420, %6422 : !torch.vtensor<[4,?],si64>, !torch.list<int> -> !torch.vtensor<[?],si64>
- torch.bind_symbolic_shape %6423, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64>
- %int32_8118 = torch.constant.int 32
- %int2_8119 = torch.constant.int 2
- %int32_8120 = torch.constant.int 32
- %int8_8121 = torch.constant.int 8
- %int128_8122 = torch.constant.int 128
- %6424 = torch.prim.ListConstruct %437, %int32_8118, %int2_8119, %int32_8120, %int8_8121, %int128_8122 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6425 = torch.aten.view %6414, %6424 : !torch.vtensor<[?,2097152],f16>, !torch.list<int> -> !torch.vtensor<[?,32,2,32,8,128],f16>
- torch.bind_symbolic_shape %6425, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16>
- %int32_8123 = torch.constant.int 32
- %6426 = torch.aten.mul.int %437, %int32_8123 : !torch.int, !torch.int -> !torch.int
- %int2_8124 = torch.constant.int 2
- %int32_8125 = torch.constant.int 32
- %int8_8126 = torch.constant.int 8
- %int128_8127 = torch.constant.int 128
- %6427 = torch.prim.ListConstruct %6426, %int2_8124, %int32_8125, %int8_8126, %int128_8127 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6428 = torch.aten.view %6425, %6427 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[?,2,32,8,128],f16>
- torch.bind_symbolic_shape %6428, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16>
- %int0_8128 = torch.constant.int 0
- %6429 = torch.aten.index_select %6428, %int0_8128, %6423 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16>
- torch.bind_symbolic_shape %6429, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16>
- %int4_8129 = torch.constant.int 4
- %int2_8130 = torch.constant.int 2
- %int32_8131 = torch.constant.int 32
- %int8_8132 = torch.constant.int 8
- %int128_8133 = torch.constant.int 128
- %6430 = torch.prim.ListConstruct %int4_8129, %358, %int2_8130, %int32_8131, %int8_8132, %int128_8133 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6431 = torch.aten.view %6429, %6430 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %6431, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int0_8134 = torch.constant.int 0
- %int0_8135 = torch.constant.int 0
- %int9223372036854775807_8136 = torch.constant.int 9223372036854775807
- %int1_8137 = torch.constant.int 1
- %6432 = torch.aten.slice.Tensor %6431, %int0_8134, %int0_8135, %int9223372036854775807_8136, %int1_8137 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %6432, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int1_8138 = torch.constant.int 1
- %int0_8139 = torch.constant.int 0
- %int9223372036854775807_8140 = torch.constant.int 9223372036854775807
- %int1_8141 = torch.constant.int 1
- %6433 = torch.aten.slice.Tensor %6432, %int1_8138, %int0_8139, %int9223372036854775807_8140, %int1_8141 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %6433, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int2_8142 = torch.constant.int 2
- %int0_8143 = torch.constant.int 0
- %6434 = torch.aten.select.int %6433, %int2_8142, %int0_8143 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %6434, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int32_8144 = torch.constant.int 32
- %6435 = torch.aten.mul.int %358, %int32_8144 : !torch.int, !torch.int -> !torch.int
- %int2_8145 = torch.constant.int 2
- %int0_8146 = torch.constant.int 0
- %int1_8147 = torch.constant.int 1
- %6436 = torch.aten.slice.Tensor %6434, %int2_8145, %int0_8146, %6435, %int1_8147 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %6436, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int0_8148 = torch.constant.int 0
- %6437 = torch.aten.clone %6436, %int0_8148 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16>
- torch.bind_symbolic_shape %6437, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16>
- %int1_8149 = torch.constant.int 1
- %6438 = torch.aten.size.int %6433, %int1_8149 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int
- %int32_8150 = torch.constant.int 32
- %6439 = torch.aten.mul.int %6438, %int32_8150 : !torch.int, !torch.int -> !torch.int
- %int4_8151 = torch.constant.int 4
- %int8_8152 = torch.constant.int 8
- %int128_8153 = torch.constant.int 128
- %6440 = torch.prim.ListConstruct %int4_8151, %6439, %int8_8152, %int128_8153 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
- %6441 = torch.aten._unsafe_view %6437, %6440 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %6441, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int0_8154 = torch.constant.int 0
- %int0_8155 = torch.constant.int 0
- %int9223372036854775807_8156 = torch.constant.int 9223372036854775807
- %int1_8157 = torch.constant.int 1
- %6442 = torch.aten.slice.Tensor %6441, %int0_8154, %int0_8155, %int9223372036854775807_8156, %int1_8157 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16>
- torch.bind_symbolic_shape %6442, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16>
- %int0_8158 = torch.constant.int 0
- %int0_8159 = torch.constant.int 0
- %int9223372036854775807_8160 = torch.constant.int 9223372036854775807
- %int1_8161 = torch.constant.int 1
- %6443 = torch.aten.slice.Tensor %6431, %int0_8158, %int0_8159, %int9223372036854775807_8160, %int1_8161 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16>
- torch.bind_symbolic_shape %6443, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16>
- %int1_8162 = torch.constant.int 1
- %int0_8163 = torch.constant.int 0 - %int9223372036854775807_8164 = torch.constant.int 9223372036854775807 - %int1_8165 = torch.constant.int 1 - %6444 = torch.aten.slice.Tensor %6443, %int1_8162, %int0_8163, %int9223372036854775807_8164, %int1_8165 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6444, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_8166 = torch.constant.int 2 - %int1_8167 = torch.constant.int 1 - %6445 = torch.aten.select.int %6444, %int2_8166, %int1_8167 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6445, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_8168 = torch.constant.int 2 - %int0_8169 = torch.constant.int 0 - %int1_8170 = torch.constant.int 1 - %6446 = torch.aten.slice.Tensor %6445, %int2_8168, %int0_8169, %6435, %int1_8170 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6446, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_8171 = torch.constant.int 0 - %6447 = torch.aten.clone %6446, %int0_8171 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6447, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_8172 = torch.constant.int 1 - %6448 = torch.aten.size.int %6444, %int1_8172 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_8173 = torch.constant.int 32 - %6449 = torch.aten.mul.int %6448, %int32_8173 : !torch.int, !torch.int -> !torch.int - %int4_8174 = torch.constant.int 4 - %int8_8175 = torch.constant.int 8 - %int128_8176 = torch.constant.int 128 - %6450 = torch.prim.ListConstruct %int4_8174, %6449, %int8_8175, %int128_8176 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6451 = torch.aten._unsafe_view %6447, %6450 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6451, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_8177 = torch.constant.int 0 - %int0_8178 = torch.constant.int 0 - %int9223372036854775807_8179 = torch.constant.int 9223372036854775807 - %int1_8180 = torch.constant.int 1 - %6452 = torch.aten.slice.Tensor %6451, %int0_8177, %int0_8178, %int9223372036854775807_8179, %int1_8180 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6452, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_8181 = torch.constant.int -2 - %6453 = torch.aten.unsqueeze %6442, %int-2_8181 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6453, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_8182 = torch.constant.int 1 - %6454 = torch.aten.size.int %6441, %int1_8182 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_8183 = torch.constant.int 4 - %int8_8184 = torch.constant.int 8 - %int4_8185 = torch.constant.int 4 - %int128_8186 = torch.constant.int 128 - %6455 = torch.prim.ListConstruct %int4_8183, %6454, %int8_8184, 
%int4_8185, %int128_8186 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_8187 = torch.constant.bool false - %6456 = torch.aten.expand %6453, %6455, %false_8187 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6456, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_8188 = torch.constant.int 0 - %6457 = torch.aten.clone %6456, %int0_8188 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6457, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_8189 = torch.constant.int 4 - %int32_8190 = torch.constant.int 32 - %int128_8191 = torch.constant.int 128 - %6458 = torch.prim.ListConstruct %int4_8189, %6454, %int32_8190, %int128_8191 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6459 = torch.aten._unsafe_view %6457, %6458 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6459, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_8192 = torch.constant.int -2 - %6460 = torch.aten.unsqueeze %6452, %int-2_8192 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6460, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_8193 = torch.constant.int 1 - %6461 = torch.aten.size.int %6451, %int1_8193 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_8194 = torch.constant.int 4 - %int8_8195 = torch.constant.int 8 - %int4_8196 = torch.constant.int 4 - %int128_8197 = torch.constant.int 128 - %6462 = torch.prim.ListConstruct %int4_8194, %6461, %int8_8195, %int4_8196, %int128_8197 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_8198 = torch.constant.bool false - %6463 = torch.aten.expand %6460, %6462, %false_8198 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6463, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_8199 = torch.constant.int 0 - %6464 = torch.aten.clone %6463, %int0_8199 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6464, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_8200 = torch.constant.int 4 - %int32_8201 = torch.constant.int 32 - %int128_8202 = torch.constant.int 128 - %6465 = torch.prim.ListConstruct %int4_8200, %6461, %int32_8201, %int128_8202 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6466 = torch.aten._unsafe_view %6464, %6465 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6466, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_8203 = torch.constant.int 1 - %int2_8204 = torch.constant.int 2 - %6467 = torch.aten.transpose.int %6347, %int1_8203, %int2_8204 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_8205 = torch.constant.int 1 - %int2_8206 = torch.constant.int 2 - %6468 = torch.aten.transpose.int %6459, %int1_8205, %int2_8206 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> 
!torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6468, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_8207 = torch.constant.int 1 - %int2_8208 = torch.constant.int 2 - %6469 = torch.aten.transpose.int %6466, %int1_8207, %int2_8208 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6469, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_8209 = torch.constant.float 0.000000e+00 - %false_8210 = torch.constant.bool false - %none_8211 = torch.constant.none - %6470:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6467, %6468, %6469, %float0.000000e00_8209, %false_8210, %368, %none_8211) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_8212 = torch.constant.int 1 - %int2_8213 = torch.constant.int 2 - %6471 = torch.aten.transpose.int %6470#0, %int1_8212, %int2_8213 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_8214 = torch.constant.int 4 - %int1_8215 = torch.constant.int 1 - %int4096_8216 = torch.constant.int 4096 - %6472 = torch.prim.ListConstruct %int4_8214, %int1_8215, %int4096_8216 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6473 = torch.aten.view %6471, %6472 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_8217 = torch.constant.int -2 - %int-1_8218 = torch.constant.int -1 - %6474 = torch.aten.transpose.int %326, %int-2_8217, %int-1_8218 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_8219 = torch.constant.int 4 - %int4096_8220 = torch.constant.int 4096 - %6475 = torch.prim.ListConstruct %int4_8219, %int4096_8220 : (!torch.int, !torch.int) -> !torch.list - %6476 = torch.aten.view %6473, %6475 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6477 = torch.aten.mm %6476, %6474 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_8221 = torch.constant.int 4 - %int1_8222 = torch.constant.int 1 - %int4096_8223 = torch.constant.int 4096 - %6478 = torch.prim.ListConstruct %int4_8221, %int1_8222, %int4096_8223 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6479 = torch.aten.view %6477, %6478 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_8224 = torch.constant.int 1 - %6480 = torch.aten.add.Tensor %6307, %6479, %int1_8224 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_8225 = torch.constant.int 6 - %6481 = torch.prims.convert_element_type %6480, %int6_8225 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_8226 = torch.constant.int 2 - %6482 = torch.aten.pow.Tensor_Scalar %6481, %int2_8226 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_8227 = torch.constant.int -1 - %6483 = torch.prim.ListConstruct %int-1_8227 : (!torch.int) -> !torch.list - %true_8228 = torch.constant.bool true - %none_8229 = torch.constant.none - %6484 = torch.aten.mean.dim %6482, %6483, %true_8228, %none_8229 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> 
!torch.vtensor<[4,1,1],f32> - %float9.999990e-06_8230 = torch.constant.float 9.9999997473787516E-6 - %int1_8231 = torch.constant.int 1 - %6485 = torch.aten.add.Scalar %6484, %float9.999990e-06_8230, %int1_8231 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %6486 = torch.aten.rsqrt %6485 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %6487 = torch.aten.mul.Tensor %6481, %6486 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_8232 = torch.constant.int 5 - %6488 = torch.prims.convert_element_type %6487, %int5_8232 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %6489 = torch.aten.mul.Tensor %327, %6488 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_8233 = torch.constant.int 5 - %6490 = torch.prims.convert_element_type %6489, %int5_8233 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_8234 = torch.constant.int -2 - %int-1_8235 = torch.constant.int -1 - %6491 = torch.aten.transpose.int %328, %int-2_8234, %int-1_8235 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_8236 = torch.constant.int 4 - %int4096_8237 = torch.constant.int 4096 - %6492 = torch.prim.ListConstruct %int4_8236, %int4096_8237 : (!torch.int, !torch.int) -> !torch.list - %6493 = torch.aten.view %6490, %6492 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6494 = torch.aten.mm %6493, %6491 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_8238 = torch.constant.int 4 - %int1_8239 = torch.constant.int 1 - %int14336_8240 = torch.constant.int 14336 - %6495 = torch.prim.ListConstruct %int4_8238, %int1_8239, %int14336_8240 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6496 = torch.aten.view %6494, %6495 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %6497 = torch.aten.silu %6496 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_8241 = torch.constant.int -2 - %int-1_8242 = torch.constant.int -1 - %6498 = torch.aten.transpose.int %329, %int-2_8241, %int-1_8242 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_8243 = torch.constant.int 4 - %int4096_8244 = torch.constant.int 4096 - %6499 = torch.prim.ListConstruct %int4_8243, %int4096_8244 : (!torch.int, !torch.int) -> !torch.list - %6500 = torch.aten.view %6490, %6499 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6501 = torch.aten.mm %6500, %6498 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_8245 = torch.constant.int 4 - %int1_8246 = torch.constant.int 1 - %int14336_8247 = torch.constant.int 14336 - %6502 = torch.prim.ListConstruct %int4_8245, %int1_8246, %int14336_8247 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6503 = torch.aten.view %6501, %6502 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %6504 = torch.aten.mul.Tensor %6497, %6503 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_8248 = torch.constant.int -2 - %int-1_8249 = torch.constant.int -1 - %6505 = torch.aten.transpose.int %330, %int-2_8248, %int-1_8249 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> 
!torch.vtensor<[14336,4096],f16> - %int4_8250 = torch.constant.int 4 - %int14336_8251 = torch.constant.int 14336 - %6506 = torch.prim.ListConstruct %int4_8250, %int14336_8251 : (!torch.int, !torch.int) -> !torch.list - %6507 = torch.aten.view %6504, %6506 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %6508 = torch.aten.mm %6507, %6505 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_8252 = torch.constant.int 4 - %int1_8253 = torch.constant.int 1 - %int4096_8254 = torch.constant.int 4096 - %6509 = torch.prim.ListConstruct %int4_8252, %int1_8253, %int4096_8254 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6510 = torch.aten.view %6508, %6509 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_8255 = torch.constant.int 1 - %6511 = torch.aten.add.Tensor %6480, %6510, %int1_8255 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_8256 = torch.constant.int 6 - %6512 = torch.prims.convert_element_type %6511, %int6_8256 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_8257 = torch.constant.int 2 - %6513 = torch.aten.pow.Tensor_Scalar %6512, %int2_8257 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_8258 = torch.constant.int -1 - %6514 = torch.prim.ListConstruct %int-1_8258 : (!torch.int) -> !torch.list - %true_8259 = torch.constant.bool true - %none_8260 = torch.constant.none - %6515 = torch.aten.mean.dim %6513, %6514, %true_8259, %none_8260 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_8261 = torch.constant.float 9.9999997473787516E-6 - %int1_8262 = torch.constant.int 1 - %6516 = torch.aten.add.Scalar %6515, %float9.999990e-06_8261, %int1_8262 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %6517 = torch.aten.rsqrt %6516 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %6518 = torch.aten.mul.Tensor %6512, %6517 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_8263 = torch.constant.int 5 - %6519 = torch.prims.convert_element_type %6518, %int5_8263 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %6520 = torch.aten.mul.Tensor %331, %6519 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_8264 = torch.constant.int 5 - %6521 = torch.prims.convert_element_type %6520, %int5_8264 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_8265 = torch.constant.int -2 - %int-1_8266 = torch.constant.int -1 - %6522 = torch.aten.transpose.int %332, %int-2_8265, %int-1_8266 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_8267 = torch.constant.int 4 - %int4096_8268 = torch.constant.int 4096 - %6523 = torch.prim.ListConstruct %int4_8267, %int4096_8268 : (!torch.int, !torch.int) -> !torch.list - %6524 = torch.aten.view %6521, %6523 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6525 = torch.aten.mm %6524, %6522 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_8269 = torch.constant.int 4 - %int1_8270 = torch.constant.int 1 - %int4096_8271 = torch.constant.int 4096 - %6526 = torch.prim.ListConstruct 
%int4_8269, %int1_8270, %int4096_8271 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6527 = torch.aten.view %6525, %6526 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_8272 = torch.constant.int -2 - %int-1_8273 = torch.constant.int -1 - %6528 = torch.aten.transpose.int %333, %int-2_8272, %int-1_8273 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_8274 = torch.constant.int 4 - %int4096_8275 = torch.constant.int 4096 - %6529 = torch.prim.ListConstruct %int4_8274, %int4096_8275 : (!torch.int, !torch.int) -> !torch.list - %6530 = torch.aten.view %6521, %6529 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6531 = torch.aten.mm %6530, %6528 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_8276 = torch.constant.int 4 - %int1_8277 = torch.constant.int 1 - %int1024_8278 = torch.constant.int 1024 - %6532 = torch.prim.ListConstruct %int4_8276, %int1_8277, %int1024_8278 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6533 = torch.aten.view %6531, %6532 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_8279 = torch.constant.int -2 - %int-1_8280 = torch.constant.int -1 - %6534 = torch.aten.transpose.int %334, %int-2_8279, %int-1_8280 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_8281 = torch.constant.int 4 - %int4096_8282 = torch.constant.int 4096 - %6535 = torch.prim.ListConstruct %int4_8281, %int4096_8282 : (!torch.int, !torch.int) -> !torch.list - %6536 = torch.aten.view %6521, %6535 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6537 = torch.aten.mm %6536, %6534 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_8283 = torch.constant.int 4 - %int1_8284 = torch.constant.int 1 - %int1024_8285 = torch.constant.int 1024 - %6538 = torch.prim.ListConstruct %int4_8283, %int1_8284, %int1024_8285 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6539 = torch.aten.view %6537, %6538 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_8286 = torch.constant.int 4 - %int1_8287 = torch.constant.int 1 - %int32_8288 = torch.constant.int 32 - %int128_8289 = torch.constant.int 128 - %6540 = torch.prim.ListConstruct %int4_8286, %int1_8287, %int32_8288, %int128_8289 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6541 = torch.aten.view %6527, %6540 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_8290 = torch.constant.int 4 - %int1_8291 = torch.constant.int 1 - %int8_8292 = torch.constant.int 8 - %int128_8293 = torch.constant.int 128 - %6542 = torch.prim.ListConstruct %int4_8290, %int1_8291, %int8_8292, %int128_8293 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6543 = torch.aten.view %6533, %6542 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_8294 = torch.constant.int 4 - %int1_8295 = torch.constant.int 1 - %int8_8296 = torch.constant.int 8 - %int128_8297 = torch.constant.int 128 - %6544 = torch.prim.ListConstruct %int4_8294, %int1_8295, %int8_8296, %int128_8297 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6545 = torch.aten.view %6539, %6544 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_8298 = torch.constant.int 6 - 
%6546 = torch.prims.convert_element_type %6541, %int6_8298 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %6547 = torch_c.to_builtin_tensor %6546 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %6548 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %6549 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%6547, %6548) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %6550 = torch_c.from_builtin_tensor %6549 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_8299 = torch.constant.int 5 - %6551 = torch.prims.convert_element_type %6550, %int5_8299 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_8300 = torch.constant.int 6 - %6552 = torch.prims.convert_element_type %6543, %int6_8300 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %6553 = torch_c.to_builtin_tensor %6552 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %6554 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %6555 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%6553, %6554) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %6556 = torch_c.from_builtin_tensor %6555 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_8301 = torch.constant.int 5 - %6557 = torch.prims.convert_element_type %6556, %int5_8301 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_8302 = torch.constant.int 32 - %6558 = torch.aten.floor_divide.Scalar %arg2, %int32_8302 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_8303 = torch.constant.int 1 - %6559 = torch.aten.unsqueeze %6558, %int1_8303 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8304 = torch.constant.int 1 - %false_8305 = torch.constant.bool false - %6560 = torch.aten.gather %arg3, %int1_8304, %6559, %false_8305 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_8306 = torch.constant.int 32 - %6561 = torch.aten.remainder.Scalar %arg2, %int32_8306 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_8307 = torch.constant.int 1 - %6562 = torch.aten.unsqueeze %6561, %int1_8307 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_8308 = torch.constant.none - %6563 = torch.aten.clone %335, %none_8308 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_8309 = torch.constant.int 0 - %6564 = torch.aten.unsqueeze %6563, %int0_8309 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_8310 = torch.constant.int 4 - %int1_8311 = torch.constant.int 1 - %6565 = torch.prim.ListConstruct %int4_8310, %int1_8311 : (!torch.int, !torch.int) -> !torch.list - %int1_8312 = torch.constant.int 1 - %int1_8313 = torch.constant.int 1 - %6566 = torch.prim.ListConstruct %int1_8312, %int1_8313 : (!torch.int, !torch.int) -> !torch.list - %int4_8314 = torch.constant.int 4 - %int0_8315 = torch.constant.int 0 - %cpu_8316 = torch.constant.device "cpu" - %false_8317 = torch.constant.bool false - %6567 = torch.aten.empty_strided %6565, %6566, %int4_8314, %int0_8315, %cpu_8316, %false_8317 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int30 = torch.constant.int 30 - %6568 = torch.aten.fill.Scalar %6567, %int30 : 
!torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_8318 = torch.constant.int 4 - %int1_8319 = torch.constant.int 1 - %6569 = torch.prim.ListConstruct %int4_8318, %int1_8319 : (!torch.int, !torch.int) -> !torch.list - %6570 = torch.aten.repeat %6564, %6569 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_8320 = torch.constant.int 32 - %6571 = torch.aten.mul.Scalar %6560, %int32_8320 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8321 = torch.constant.int 1 - %6572 = torch.aten.add.Tensor %6571, %6568, %int1_8321 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_8322 = torch.constant.int 2 - %6573 = torch.aten.mul.Scalar %6572, %int2_8322 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8323 = torch.constant.int 1 - %6574 = torch.aten.add.Tensor %6573, %6570, %int1_8323 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_8324 = torch.constant.int 32 - %6575 = torch.aten.mul.Scalar %6574, %int32_8324 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8325 = torch.constant.int 1 - %6576 = torch.aten.add.Tensor %6575, %6562, %int1_8325 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_8326 = torch.constant.int 32 - %int2_8327 = torch.constant.int 2 - %int32_8328 = torch.constant.int 32 - %int8_8329 = torch.constant.int 8 - %int128_8330 = torch.constant.int 128 - %6577 = torch.prim.ListConstruct %437, %int32_8326, %int2_8327, %int32_8328, %int8_8329, %int128_8330 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6578 = torch.aten.view %6414, %6577 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6578, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_8331 = torch.constant.int 32 - %6579 = torch.aten.mul.int %437, %int32_8331 : !torch.int, !torch.int -> !torch.int - %int2_8332 = torch.constant.int 2 - %6580 = torch.aten.mul.int %6579, %int2_8332 : !torch.int, !torch.int -> !torch.int - %int32_8333 = torch.constant.int 32 - %6581 = torch.aten.mul.int %6580, %int32_8333 : !torch.int, !torch.int -> !torch.int - %int8_8334 = torch.constant.int 8 - %int128_8335 = torch.constant.int 128 - %6582 = torch.prim.ListConstruct %6581, %int8_8334, %int128_8335 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6583 = torch.aten.view %6578, %6582 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6583, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %6584 = torch.prim.ListConstruct %6576 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_8336 = torch.constant.bool false - %6585 = torch.aten.index_put %6583, %6584, %6557, %false_8336 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6585, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_8337 = torch.constant.int 32 - %int2_8338 = torch.constant.int 2 - %int32_8339 = torch.constant.int 32 - %int8_8340 = torch.constant.int 8 - %int128_8341 = torch.constant.int 128 - %6586 = torch.prim.ListConstruct %437, %int32_8337, %int2_8338, 
%int32_8339, %int8_8340, %int128_8341 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6587 = torch.aten.view %6585, %6586 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6587, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_8342 = torch.constant.int 2097152 - %6588 = torch.prim.ListConstruct %437, %int2097152_8342 : (!torch.int, !torch.int) -> !torch.list - %6589 = torch.aten.view %6587, %6588 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6589, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_8343 = torch.constant.int 32 - %int2_8344 = torch.constant.int 2 - %int32_8345 = torch.constant.int 32 - %int8_8346 = torch.constant.int 8 - %int128_8347 = torch.constant.int 128 - %6590 = torch.prim.ListConstruct %437, %int32_8343, %int2_8344, %int32_8345, %int8_8346, %int128_8347 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6591 = torch.aten.view %6589, %6590 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6591, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_8348 = torch.constant.int 8 - %int128_8349 = torch.constant.int 128 - %6592 = torch.prim.ListConstruct %6581, %int8_8348, %int128_8349 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6593 = torch.aten.view %6591, %6592 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6593, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_8350 = torch.constant.int 32 - %6594 = torch.aten.floor_divide.Scalar %arg2, %int32_8350 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_8351 = torch.constant.int 1 - %6595 = torch.aten.unsqueeze %6594, %int1_8351 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8352 = torch.constant.int 1 - %false_8353 = torch.constant.bool false - %6596 = torch.aten.gather %arg3, %int1_8352, %6595, %false_8353 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_8354 = torch.constant.int 32 - %6597 = torch.aten.remainder.Scalar %arg2, %int32_8354 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_8355 = torch.constant.int 1 - %6598 = torch.aten.unsqueeze %6597, %int1_8355 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_8356 = torch.constant.none - %6599 = torch.aten.clone %336, %none_8356 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_8357 = torch.constant.int 0 - %6600 = torch.aten.unsqueeze %6599, %int0_8357 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_8358 = torch.constant.int 4 - %int1_8359 = torch.constant.int 1 - %6601 = torch.prim.ListConstruct %int4_8358, %int1_8359 : (!torch.int, !torch.int) -> !torch.list - %int1_8360 = torch.constant.int 1 - %int1_8361 = torch.constant.int 1 - %6602 = torch.prim.ListConstruct %int1_8360, %int1_8361 : (!torch.int, !torch.int) -> !torch.list - %int4_8362 = torch.constant.int 4 - %int0_8363 = torch.constant.int 0 - %cpu_8364 = torch.constant.device "cpu" - %false_8365 = torch.constant.bool false - %6603 = 
torch.aten.empty_strided %6601, %6602, %int4_8362, %int0_8363, %cpu_8364, %false_8365 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int30_8366 = torch.constant.int 30 - %6604 = torch.aten.fill.Scalar %6603, %int30_8366 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_8367 = torch.constant.int 4 - %int1_8368 = torch.constant.int 1 - %6605 = torch.prim.ListConstruct %int4_8367, %int1_8368 : (!torch.int, !torch.int) -> !torch.list - %6606 = torch.aten.repeat %6600, %6605 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_8369 = torch.constant.int 32 - %6607 = torch.aten.mul.Scalar %6596, %int32_8369 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8370 = torch.constant.int 1 - %6608 = torch.aten.add.Tensor %6607, %6604, %int1_8370 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_8371 = torch.constant.int 2 - %6609 = torch.aten.mul.Scalar %6608, %int2_8371 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8372 = torch.constant.int 1 - %6610 = torch.aten.add.Tensor %6609, %6606, %int1_8372 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_8373 = torch.constant.int 32 - %6611 = torch.aten.mul.Scalar %6610, %int32_8373 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8374 = torch.constant.int 1 - %6612 = torch.aten.add.Tensor %6611, %6598, %int1_8374 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %6613 = torch.prim.ListConstruct %6612 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_8375 = torch.constant.bool false - %6614 = torch.aten.index_put %6593, %6613, %6545, %false_8375 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6614, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_8376 = torch.constant.int 32 - %int2_8377 = torch.constant.int 2 - %int32_8378 = torch.constant.int 32 - %int8_8379 = torch.constant.int 8 - %int128_8380 = torch.constant.int 128 - %6615 = torch.prim.ListConstruct %437, %int32_8376, %int2_8377, %int32_8378, %int8_8379, %int128_8380 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6616 = torch.aten.view %6614, %6615 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6616, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_8381 = torch.constant.int 2097152 - %6617 = torch.prim.ListConstruct %437, %int2097152_8381 : (!torch.int, !torch.int) -> !torch.list - %6618 = torch.aten.view %6616, %6617 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6618, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_8382 = torch.constant.int 4 - %6619 = torch.prim.ListConstruct %int4_8382, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_8383 = torch.constant.int 1 - %6620 = torch.prim.ListConstruct %358, %int1_8383 : (!torch.int, !torch.int) -> !torch.list - %int4_8384 = torch.constant.int 4 - %int0_8385 = torch.constant.int 0 - %cpu_8386 = torch.constant.device "cpu" - %false_8387 = 
torch.constant.bool false - %6621 = torch.aten.empty_strided %6619, %6620, %int4_8384, %int0_8385, %cpu_8386, %false_8387 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6621, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int30_8388 = torch.constant.int 30 - %6622 = torch.aten.fill.Scalar %6621, %int30_8388 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6622, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_8389 = torch.constant.int 32 - %6623 = torch.aten.mul.Scalar %arg3, %int32_8389 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6623, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_8390 = torch.constant.int 1 - %6624 = torch.aten.add.Tensor %6623, %6622, %int1_8390 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6624, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_8391 = torch.constant.int 4 - %6625 = torch.aten.mul.int %int4_8391, %358 : !torch.int, !torch.int -> !torch.int - %6626 = torch.prim.ListConstruct %6625 : (!torch.int) -> !torch.list - %6627 = torch.aten.view %6624, %6626 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %6627, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_8392 = torch.constant.int 32 - %int2_8393 = torch.constant.int 2 - %int32_8394 = torch.constant.int 32 - %int8_8395 = torch.constant.int 8 - %int128_8396 = torch.constant.int 128 - %6628 = torch.prim.ListConstruct %437, %int32_8392, %int2_8393, %int32_8394, %int8_8395, %int128_8396 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6629 = torch.aten.view %6618, %6628 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6629, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_8397 = torch.constant.int 32 - %6630 = torch.aten.mul.int %437, %int32_8397 : !torch.int, !torch.int -> !torch.int - %int2_8398 = torch.constant.int 2 - %int32_8399 = torch.constant.int 32 - %int8_8400 = torch.constant.int 8 - %int128_8401 = torch.constant.int 128 - %6631 = torch.prim.ListConstruct %6630, %int2_8398, %int32_8399, %int8_8400, %int128_8401 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6632 = torch.aten.view %6629, %6631 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %6632, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_8402 = torch.constant.int 0 - %6633 = torch.aten.index_select %6632, %int0_8402, %6627 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %6633, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_8403 = torch.constant.int 4 - %int2_8404 = torch.constant.int 2 - %int32_8405 = torch.constant.int 32 - %int8_8406 = torch.constant.int 8 - %int128_8407 = torch.constant.int 128 - %6634 = torch.prim.ListConstruct %int4_8403, %358, %int2_8404, %int32_8405, %int8_8406, %int128_8407 : (!torch.int, !torch.int, !torch.int, 
!torch.int, !torch.int, !torch.int) -> !torch.list - %6635 = torch.aten.view %6633, %6634 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6635, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_8408 = torch.constant.int 0 - %int0_8409 = torch.constant.int 0 - %int9223372036854775807_8410 = torch.constant.int 9223372036854775807 - %int1_8411 = torch.constant.int 1 - %6636 = torch.aten.slice.Tensor %6635, %int0_8408, %int0_8409, %int9223372036854775807_8410, %int1_8411 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6636, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_8412 = torch.constant.int 1 - %int0_8413 = torch.constant.int 0 - %int9223372036854775807_8414 = torch.constant.int 9223372036854775807 - %int1_8415 = torch.constant.int 1 - %6637 = torch.aten.slice.Tensor %6636, %int1_8412, %int0_8413, %int9223372036854775807_8414, %int1_8415 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6637, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_8416 = torch.constant.int 2 - %int0_8417 = torch.constant.int 0 - %6638 = torch.aten.select.int %6637, %int2_8416, %int0_8417 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6638, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_8418 = torch.constant.int 32 - %6639 = torch.aten.mul.int %358, %int32_8418 : !torch.int, !torch.int -> !torch.int - %int2_8419 = torch.constant.int 2 - %int0_8420 = torch.constant.int 0 - %int1_8421 = torch.constant.int 1 - %6640 = torch.aten.slice.Tensor %6638, %int2_8419, %int0_8420, %6639, %int1_8421 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6640, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_8422 = torch.constant.int 0 - %6641 = torch.aten.clone %6640, %int0_8422 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6641, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_8423 = torch.constant.int 1 - %6642 = torch.aten.size.int %6637, %int1_8423 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_8424 = torch.constant.int 32 - %6643 = torch.aten.mul.int %6642, %int32_8424 : !torch.int, !torch.int -> !torch.int - %int4_8425 = torch.constant.int 4 - %int8_8426 = torch.constant.int 8 - %int128_8427 = torch.constant.int 128 - %6644 = torch.prim.ListConstruct %int4_8425, %6643, %int8_8426, %int128_8427 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6645 = torch.aten._unsafe_view %6641, %6644 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6645, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_8428 = torch.constant.int 0 - %int0_8429 = torch.constant.int 0 - %int9223372036854775807_8430 = torch.constant.int 9223372036854775807 - %int1_8431 = torch.constant.int 
1 - %6646 = torch.aten.slice.Tensor %6645, %int0_8428, %int0_8429, %int9223372036854775807_8430, %int1_8431 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6646, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_8432 = torch.constant.int 0 - %int0_8433 = torch.constant.int 0 - %int9223372036854775807_8434 = torch.constant.int 9223372036854775807 - %int1_8435 = torch.constant.int 1 - %6647 = torch.aten.slice.Tensor %6635, %int0_8432, %int0_8433, %int9223372036854775807_8434, %int1_8435 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6647, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_8436 = torch.constant.int 1 - %int0_8437 = torch.constant.int 0 - %int9223372036854775807_8438 = torch.constant.int 9223372036854775807 - %int1_8439 = torch.constant.int 1 - %6648 = torch.aten.slice.Tensor %6647, %int1_8436, %int0_8437, %int9223372036854775807_8438, %int1_8439 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6648, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_8440 = torch.constant.int 2 - %int1_8441 = torch.constant.int 1 - %6649 = torch.aten.select.int %6648, %int2_8440, %int1_8441 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6649, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_8442 = torch.constant.int 2 - %int0_8443 = torch.constant.int 0 - %int1_8444 = torch.constant.int 1 - %6650 = torch.aten.slice.Tensor %6649, %int2_8442, %int0_8443, %6639, %int1_8444 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6650, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_8445 = torch.constant.int 0 - %6651 = torch.aten.clone %6650, %int0_8445 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6651, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_8446 = torch.constant.int 1 - %6652 = torch.aten.size.int %6648, %int1_8446 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_8447 = torch.constant.int 32 - %6653 = torch.aten.mul.int %6652, %int32_8447 : !torch.int, !torch.int -> !torch.int - %int4_8448 = torch.constant.int 4 - %int8_8449 = torch.constant.int 8 - %int128_8450 = torch.constant.int 128 - %6654 = torch.prim.ListConstruct %int4_8448, %6653, %int8_8449, %int128_8450 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6655 = torch.aten._unsafe_view %6651, %6654 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6655, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_8451 = torch.constant.int 0 - %int0_8452 = torch.constant.int 0 - %int9223372036854775807_8453 = torch.constant.int 9223372036854775807 - %int1_8454 = torch.constant.int 1 - %6656 = torch.aten.slice.Tensor %6655, %int0_8451, %int0_8452, 
%int9223372036854775807_8453, %int1_8454 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6656, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_8455 = torch.constant.int -2 - %6657 = torch.aten.unsqueeze %6646, %int-2_8455 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6657, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_8456 = torch.constant.int 1 - %6658 = torch.aten.size.int %6645, %int1_8456 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_8457 = torch.constant.int 4 - %int8_8458 = torch.constant.int 8 - %int4_8459 = torch.constant.int 4 - %int128_8460 = torch.constant.int 128 - %6659 = torch.prim.ListConstruct %int4_8457, %6658, %int8_8458, %int4_8459, %int128_8460 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_8461 = torch.constant.bool false - %6660 = torch.aten.expand %6657, %6659, %false_8461 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6660, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_8462 = torch.constant.int 0 - %6661 = torch.aten.clone %6660, %int0_8462 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6661, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_8463 = torch.constant.int 4 - %int32_8464 = torch.constant.int 32 - %int128_8465 = torch.constant.int 128 - %6662 = torch.prim.ListConstruct %int4_8463, %6658, %int32_8464, %int128_8465 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6663 = torch.aten._unsafe_view %6661, %6662 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6663, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_8466 = torch.constant.int -2 - %6664 = torch.aten.unsqueeze %6656, %int-2_8466 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6664, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_8467 = torch.constant.int 1 - %6665 = torch.aten.size.int %6655, %int1_8467 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_8468 = torch.constant.int 4 - %int8_8469 = torch.constant.int 8 - %int4_8470 = torch.constant.int 4 - %int128_8471 = torch.constant.int 128 - %6666 = torch.prim.ListConstruct %int4_8468, %6665, %int8_8469, %int4_8470, %int128_8471 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_8472 = torch.constant.bool false - %6667 = torch.aten.expand %6664, %6666, %false_8472 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6667, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_8473 = torch.constant.int 0 - %6668 = torch.aten.clone %6667, %int0_8473 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6668, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_8474 = 
torch.constant.int 4 - %int32_8475 = torch.constant.int 32 - %int128_8476 = torch.constant.int 128 - %6669 = torch.prim.ListConstruct %int4_8474, %6665, %int32_8475, %int128_8476 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6670 = torch.aten._unsafe_view %6668, %6669 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6670, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int1_8477 = torch.constant.int 1 - %int2_8478 = torch.constant.int 2 - %6671 = torch.aten.transpose.int %6551, %int1_8477, %int2_8478 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16> - %int1_8479 = torch.constant.int 1 - %int2_8480 = torch.constant.int 2 - %6672 = torch.aten.transpose.int %6663, %int1_8479, %int2_8480 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6672, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %int1_8481 = torch.constant.int 1 - %int2_8482 = torch.constant.int 2 - %6673 = torch.aten.transpose.int %6670, %int1_8481, %int2_8482 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16> - torch.bind_symbolic_shape %6673, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16> - %float0.000000e00_8483 = torch.constant.float 0.000000e+00 - %false_8484 = torch.constant.bool false - %none_8485 = torch.constant.none - %6674:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6671, %6672, %6673, %float0.000000e00_8483, %false_8484, %368, %none_8485) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>) - %int1_8486 = torch.constant.int 1 - %int2_8487 = torch.constant.int 2 - %6675 = torch.aten.transpose.int %6674#0, %int1_8486, %int2_8487 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int4_8488 = torch.constant.int 4 - %int1_8489 = torch.constant.int 1 - %int4096_8490 = torch.constant.int 4096 - %6676 = torch.prim.ListConstruct %int4_8488, %int1_8489, %int4096_8490 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6677 = torch.aten.view %6675, %6676 : !torch.vtensor<[4,1,32,128],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_8491 = torch.constant.int -2 - %int-1_8492 = torch.constant.int -1 - %6678 = torch.aten.transpose.int %337, %int-2_8491, %int-1_8492 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_8493 = torch.constant.int 4 - %int4096_8494 = torch.constant.int 4096 - %6679 = torch.prim.ListConstruct %int4_8493, %int4096_8494 : (!torch.int, !torch.int) -> !torch.list - %6680 = torch.aten.view %6677, %6679 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6681 = torch.aten.mm %6680, %6678 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_8495 = torch.constant.int 4 - %int1_8496 = torch.constant.int 1 - %int4096_8497 = torch.constant.int 4096 - %6682 = torch.prim.ListConstruct %int4_8495, %int1_8496, %int4096_8497 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6683 = torch.aten.view %6681, %6682 : !torch.vtensor<[4,4096],f16>, 
!torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_8498 = torch.constant.int 1 - %6684 = torch.aten.add.Tensor %6511, %6683, %int1_8498 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_8499 = torch.constant.int 6 - %6685 = torch.prims.convert_element_type %6684, %int6_8499 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_8500 = torch.constant.int 2 - %6686 = torch.aten.pow.Tensor_Scalar %6685, %int2_8500 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_8501 = torch.constant.int -1 - %6687 = torch.prim.ListConstruct %int-1_8501 : (!torch.int) -> !torch.list - %true_8502 = torch.constant.bool true - %none_8503 = torch.constant.none - %6688 = torch.aten.mean.dim %6686, %6687, %true_8502, %none_8503 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_8504 = torch.constant.float 9.9999997473787516E-6 - %int1_8505 = torch.constant.int 1 - %6689 = torch.aten.add.Scalar %6688, %float9.999990e-06_8504, %int1_8505 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %6690 = torch.aten.rsqrt %6689 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %6691 = torch.aten.mul.Tensor %6685, %6690 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_8506 = torch.constant.int 5 - %6692 = torch.prims.convert_element_type %6691, %int5_8506 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %6693 = torch.aten.mul.Tensor %338, %6692 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_8507 = torch.constant.int 5 - %6694 = torch.prims.convert_element_type %6693, %int5_8507 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_8508 = torch.constant.int -2 - %int-1_8509 = torch.constant.int -1 - %6695 = torch.aten.transpose.int %339, %int-2_8508, %int-1_8509 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_8510 = torch.constant.int 4 - %int4096_8511 = torch.constant.int 4096 - %6696 = torch.prim.ListConstruct %int4_8510, %int4096_8511 : (!torch.int, !torch.int) -> !torch.list - %6697 = torch.aten.view %6694, %6696 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6698 = torch.aten.mm %6697, %6695 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_8512 = torch.constant.int 4 - %int1_8513 = torch.constant.int 1 - %int14336_8514 = torch.constant.int 14336 - %6699 = torch.prim.ListConstruct %int4_8512, %int1_8513, %int14336_8514 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6700 = torch.aten.view %6698, %6699 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %6701 = torch.aten.silu %6700 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_8515 = torch.constant.int -2 - %int-1_8516 = torch.constant.int -1 - %6702 = torch.aten.transpose.int %340, %int-2_8515, %int-1_8516 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16> - %int4_8517 = torch.constant.int 4 - %int4096_8518 = torch.constant.int 4096 - %6703 = torch.prim.ListConstruct %int4_8517, %int4096_8518 : (!torch.int, !torch.int) -> !torch.list - %6704 = torch.aten.view %6694, %6703 : 
!torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6705 = torch.aten.mm %6704, %6702 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16> - %int4_8519 = torch.constant.int 4 - %int1_8520 = torch.constant.int 1 - %int14336_8521 = torch.constant.int 14336 - %6706 = torch.prim.ListConstruct %int4_8519, %int1_8520, %int14336_8521 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6707 = torch.aten.view %6705, %6706 : !torch.vtensor<[4,14336],f16>, !torch.list -> !torch.vtensor<[4,1,14336],f16> - %6708 = torch.aten.mul.Tensor %6701, %6707 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16> - %int-2_8522 = torch.constant.int -2 - %int-1_8523 = torch.constant.int -1 - %6709 = torch.aten.transpose.int %341, %int-2_8522, %int-1_8523 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16> - %int4_8524 = torch.constant.int 4 - %int14336_8525 = torch.constant.int 14336 - %6710 = torch.prim.ListConstruct %int4_8524, %int14336_8525 : (!torch.int, !torch.int) -> !torch.list - %6711 = torch.aten.view %6708, %6710 : !torch.vtensor<[4,1,14336],f16>, !torch.list -> !torch.vtensor<[4,14336],f16> - %6712 = torch.aten.mm %6711, %6709 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_8526 = torch.constant.int 4 - %int1_8527 = torch.constant.int 1 - %int4096_8528 = torch.constant.int 4096 - %6713 = torch.prim.ListConstruct %int4_8526, %int1_8527, %int4096_8528 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6714 = torch.aten.view %6712, %6713 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int1_8529 = torch.constant.int 1 - %6715 = torch.aten.add.Tensor %6684, %6714, %int1_8529 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int6_8530 = torch.constant.int 6 - %6716 = torch.prims.convert_element_type %6715, %int6_8530 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int2_8531 = torch.constant.int 2 - %6717 = torch.aten.pow.Tensor_Scalar %6716, %int2_8531 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32> - %int-1_8532 = torch.constant.int -1 - %6718 = torch.prim.ListConstruct %int-1_8532 : (!torch.int) -> !torch.list - %true_8533 = torch.constant.bool true - %none_8534 = torch.constant.none - %6719 = torch.aten.mean.dim %6717, %6718, %true_8533, %none_8534 : !torch.vtensor<[4,1,4096],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32> - %float9.999990e-06_8535 = torch.constant.float 9.9999997473787516E-6 - %int1_8536 = torch.constant.int 1 - %6720 = torch.aten.add.Scalar %6719, %float9.999990e-06_8535, %int1_8536 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32> - %6721 = torch.aten.rsqrt %6720 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32> - %6722 = torch.aten.mul.Tensor %6716, %6721 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32> - %int5_8537 = torch.constant.int 5 - %6723 = torch.prims.convert_element_type %6722, %int5_8537 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %6724 = torch.aten.mul.Tensor %342, %6723 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32> - %int5_8538 = torch.constant.int 5 - %6725 = 
torch.prims.convert_element_type %6724, %int5_8538 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16> - %int-2_8539 = torch.constant.int -2 - %int-1_8540 = torch.constant.int -1 - %6726 = torch.aten.transpose.int %343, %int-2_8539, %int-1_8540 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16> - %int4_8541 = torch.constant.int 4 - %int4096_8542 = torch.constant.int 4096 - %6727 = torch.prim.ListConstruct %int4_8541, %int4096_8542 : (!torch.int, !torch.int) -> !torch.list - %6728 = torch.aten.view %6725, %6727 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6729 = torch.aten.mm %6728, %6726 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16> - %int4_8543 = torch.constant.int 4 - %int1_8544 = torch.constant.int 1 - %int4096_8545 = torch.constant.int 4096 - %6730 = torch.prim.ListConstruct %int4_8543, %int1_8544, %int4096_8545 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6731 = torch.aten.view %6729, %6730 : !torch.vtensor<[4,4096],f16>, !torch.list -> !torch.vtensor<[4,1,4096],f16> - %int-2_8546 = torch.constant.int -2 - %int-1_8547 = torch.constant.int -1 - %6732 = torch.aten.transpose.int %344, %int-2_8546, %int-1_8547 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_8548 = torch.constant.int 4 - %int4096_8549 = torch.constant.int 4096 - %6733 = torch.prim.ListConstruct %int4_8548, %int4096_8549 : (!torch.int, !torch.int) -> !torch.list - %6734 = torch.aten.view %6725, %6733 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6735 = torch.aten.mm %6734, %6732 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_8550 = torch.constant.int 4 - %int1_8551 = torch.constant.int 1 - %int1024_8552 = torch.constant.int 1024 - %6736 = torch.prim.ListConstruct %int4_8550, %int1_8551, %int1024_8552 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6737 = torch.aten.view %6735, %6736 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int-2_8553 = torch.constant.int -2 - %int-1_8554 = torch.constant.int -1 - %6738 = torch.aten.transpose.int %345, %int-2_8553, %int-1_8554 : !torch.vtensor<[1024,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,1024],f16> - %int4_8555 = torch.constant.int 4 - %int4096_8556 = torch.constant.int 4096 - %6739 = torch.prim.ListConstruct %int4_8555, %int4096_8556 : (!torch.int, !torch.int) -> !torch.list - %6740 = torch.aten.view %6725, %6739 : !torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,4096],f16> - %6741 = torch.aten.mm %6740, %6738 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,1024],f16> -> !torch.vtensor<[4,1024],f16> - %int4_8557 = torch.constant.int 4 - %int1_8558 = torch.constant.int 1 - %int1024_8559 = torch.constant.int 1024 - %6742 = torch.prim.ListConstruct %int4_8557, %int1_8558, %int1024_8559 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6743 = torch.aten.view %6741, %6742 : !torch.vtensor<[4,1024],f16>, !torch.list -> !torch.vtensor<[4,1,1024],f16> - %int4_8560 = torch.constant.int 4 - %int1_8561 = torch.constant.int 1 - %int32_8562 = torch.constant.int 32 - %int128_8563 = torch.constant.int 128 - %6744 = torch.prim.ListConstruct %int4_8560, %int1_8561, %int32_8562, %int128_8563 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6745 = torch.aten.view %6731, %6744 : 
!torch.vtensor<[4,1,4096],f16>, !torch.list -> !torch.vtensor<[4,1,32,128],f16> - %int4_8564 = torch.constant.int 4 - %int1_8565 = torch.constant.int 1 - %int8_8566 = torch.constant.int 8 - %int128_8567 = torch.constant.int 128 - %6746 = torch.prim.ListConstruct %int4_8564, %int1_8565, %int8_8566, %int128_8567 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6747 = torch.aten.view %6737, %6746 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int4_8568 = torch.constant.int 4 - %int1_8569 = torch.constant.int 1 - %int8_8570 = torch.constant.int 8 - %int128_8571 = torch.constant.int 128 - %6748 = torch.prim.ListConstruct %int4_8568, %int1_8569, %int8_8570, %int128_8571 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6749 = torch.aten.view %6743, %6748 : !torch.vtensor<[4,1,1024],f16>, !torch.list -> !torch.vtensor<[4,1,8,128],f16> - %int6_8572 = torch.constant.int 6 - %6750 = torch.prims.convert_element_type %6745, %int6_8572 : !torch.vtensor<[4,1,32,128],f16>, !torch.int -> !torch.vtensor<[4,1,32,128],f32> - %6751 = torch_c.to_builtin_tensor %6750 : !torch.vtensor<[4,1,32,128],f32> -> tensor<4x1x32x128xf32> - %6752 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %6753 = util.call @sharktank_rotary_embedding_4_1_32_128_f32(%6751, %6752) : (tensor<4x1x32x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> - %6754 = torch_c.from_builtin_tensor %6753 : tensor<4x1x32x128xf32> -> !torch.vtensor<[4,1,32,128],f32> - %int5_8573 = torch.constant.int 5 - %6755 = torch.prims.convert_element_type %6754, %int5_8573 : !torch.vtensor<[4,1,32,128],f32>, !torch.int -> !torch.vtensor<[4,1,32,128],f16> - %int6_8574 = torch.constant.int 6 - %6756 = torch.prims.convert_element_type %6747, %int6_8574 : !torch.vtensor<[4,1,8,128],f16>, !torch.int -> !torch.vtensor<[4,1,8,128],f32> - %6757 = torch_c.to_builtin_tensor %6756 : !torch.vtensor<[4,1,8,128],f32> -> tensor<4x1x8x128xf32> - %6758 = torch_c.to_builtin_tensor %389 : !torch.vtensor<[4,1,128],f32> -> tensor<4x1x128xf32> - %6759 = util.call @sharktank_rotary_embedding_4_1_8_128_f32(%6757, %6758) : (tensor<4x1x8x128xf32>, tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> - %6760 = torch_c.from_builtin_tensor %6759 : tensor<4x1x8x128xf32> -> !torch.vtensor<[4,1,8,128],f32> - %int5_8575 = torch.constant.int 5 - %6761 = torch.prims.convert_element_type %6760, %int5_8575 : !torch.vtensor<[4,1,8,128],f32>, !torch.int -> !torch.vtensor<[4,1,8,128],f16> - %int32_8576 = torch.constant.int 32 - %6762 = torch.aten.floor_divide.Scalar %arg2, %int32_8576 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_8577 = torch.constant.int 1 - %6763 = torch.aten.unsqueeze %6762, %int1_8577 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8578 = torch.constant.int 1 - %false_8579 = torch.constant.bool false - %6764 = torch.aten.gather %arg3, %int1_8578, %6763, %false_8579 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_8580 = torch.constant.int 32 - %6765 = torch.aten.remainder.Scalar %arg2, %int32_8580 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_8581 = torch.constant.int 1 - %6766 = torch.aten.unsqueeze %6765, %int1_8581 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_8582 = torch.constant.none - %6767 = torch.aten.clone %346, %none_8582 : !torch.vtensor<[],si64>, !torch.none -> 
!torch.vtensor<[],si64> - %int0_8583 = torch.constant.int 0 - %6768 = torch.aten.unsqueeze %6767, %int0_8583 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_8584 = torch.constant.int 4 - %int1_8585 = torch.constant.int 1 - %6769 = torch.prim.ListConstruct %int4_8584, %int1_8585 : (!torch.int, !torch.int) -> !torch.list - %int1_8586 = torch.constant.int 1 - %int1_8587 = torch.constant.int 1 - %6770 = torch.prim.ListConstruct %int1_8586, %int1_8587 : (!torch.int, !torch.int) -> !torch.list - %int4_8588 = torch.constant.int 4 - %int0_8589 = torch.constant.int 0 - %cpu_8590 = torch.constant.device "cpu" - %false_8591 = torch.constant.bool false - %6771 = torch.aten.empty_strided %6769, %6770, %int4_8588, %int0_8589, %cpu_8590, %false_8591 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int31 = torch.constant.int 31 - %6772 = torch.aten.fill.Scalar %6771, %int31 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_8592 = torch.constant.int 4 - %int1_8593 = torch.constant.int 1 - %6773 = torch.prim.ListConstruct %int4_8592, %int1_8593 : (!torch.int, !torch.int) -> !torch.list - %6774 = torch.aten.repeat %6768, %6773 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_8594 = torch.constant.int 32 - %6775 = torch.aten.mul.Scalar %6764, %int32_8594 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8595 = torch.constant.int 1 - %6776 = torch.aten.add.Tensor %6775, %6772, %int1_8595 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_8596 = torch.constant.int 2 - %6777 = torch.aten.mul.Scalar %6776, %int2_8596 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8597 = torch.constant.int 1 - %6778 = torch.aten.add.Tensor %6777, %6774, %int1_8597 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_8598 = torch.constant.int 32 - %6779 = torch.aten.mul.Scalar %6778, %int32_8598 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8599 = torch.constant.int 1 - %6780 = torch.aten.add.Tensor %6779, %6766, %int1_8599 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_8600 = torch.constant.int 32 - %int2_8601 = torch.constant.int 2 - %int32_8602 = torch.constant.int 32 - %int8_8603 = torch.constant.int 8 - %int128_8604 = torch.constant.int 128 - %6781 = torch.prim.ListConstruct %437, %int32_8600, %int2_8601, %int32_8602, %int8_8603, %int128_8604 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6782 = torch.aten.view %6618, %6781 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6782, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_8605 = torch.constant.int 32 - %6783 = torch.aten.mul.int %437, %int32_8605 : !torch.int, !torch.int -> !torch.int - %int2_8606 = torch.constant.int 2 - %6784 = torch.aten.mul.int %6783, %int2_8606 : !torch.int, !torch.int -> !torch.int - %int32_8607 = torch.constant.int 32 - %6785 = torch.aten.mul.int %6784, %int32_8607 : !torch.int, !torch.int -> !torch.int - %int8_8608 = torch.constant.int 8 - %int128_8609 = torch.constant.int 128 - %6786 = torch.prim.ListConstruct %6785, %int8_8608, %int128_8609 : (!torch.int, !torch.int, 
!torch.int) -> !torch.list - %6787 = torch.aten.view %6782, %6786 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6787, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %6788 = torch.prim.ListConstruct %6780 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_8610 = torch.constant.bool false - %6789 = torch.aten.index_put %6787, %6788, %6761, %false_8610 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6789, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_8611 = torch.constant.int 32 - %int2_8612 = torch.constant.int 2 - %int32_8613 = torch.constant.int 32 - %int8_8614 = torch.constant.int 8 - %int128_8615 = torch.constant.int 128 - %6790 = torch.prim.ListConstruct %437, %int32_8611, %int2_8612, %int32_8613, %int8_8614, %int128_8615 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6791 = torch.aten.view %6789, %6790 : !torch.vtensor<[?,8,128],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6791, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_8616 = torch.constant.int 2097152 - %6792 = torch.prim.ListConstruct %437, %int2097152_8616 : (!torch.int, !torch.int) -> !torch.list - %6793 = torch.aten.view %6791, %6792 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.bind_symbolic_shape %6793, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int32_8617 = torch.constant.int 32 - %int2_8618 = torch.constant.int 2 - %int32_8619 = torch.constant.int 32 - %int8_8620 = torch.constant.int 8 - %int128_8621 = torch.constant.int 128 - %6794 = torch.prim.ListConstruct %437, %int32_8617, %int2_8618, %int32_8619, %int8_8620, %int128_8621 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6795 = torch.aten.view %6793, %6794 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6795, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int8_8622 = torch.constant.int 8 - %int128_8623 = torch.constant.int 128 - %6796 = torch.prim.ListConstruct %6785, %int8_8622, %int128_8623 : (!torch.int, !torch.int, !torch.int) -> !torch.list - %6797 = torch.aten.view %6795, %6796 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6797, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_8624 = torch.constant.int 32 - %6798 = torch.aten.floor_divide.Scalar %arg2, %int32_8624 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - %int1_8625 = torch.constant.int 1 - %6799 = torch.aten.unsqueeze %6798, %int1_8625 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8626 = torch.constant.int 1 - %false_8627 = torch.constant.bool false - %6800 = torch.aten.gather %arg3, %int1_8626, %6799, %false_8627 : !torch.vtensor<[4,?],si64>, !torch.int, !torch.vtensor<[4,1],si64>, !torch.bool -> !torch.vtensor<[4,1],si64> - %int32_8628 = torch.constant.int 32 - %6801 = torch.aten.remainder.Scalar %arg2, %int32_8628 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> - 
%int1_8629 = torch.constant.int 1 - %6802 = torch.aten.unsqueeze %6801, %int1_8629 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %none_8630 = torch.constant.none - %6803 = torch.aten.clone %347, %none_8630 : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64> - %int0_8631 = torch.constant.int 0 - %6804 = torch.aten.unsqueeze %6803, %int0_8631 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> - %int4_8632 = torch.constant.int 4 - %int1_8633 = torch.constant.int 1 - %6805 = torch.prim.ListConstruct %int4_8632, %int1_8633 : (!torch.int, !torch.int) -> !torch.list - %int1_8634 = torch.constant.int 1 - %int1_8635 = torch.constant.int 1 - %6806 = torch.prim.ListConstruct %int1_8634, %int1_8635 : (!torch.int, !torch.int) -> !torch.list - %int4_8636 = torch.constant.int 4 - %int0_8637 = torch.constant.int 0 - %cpu_8638 = torch.constant.device "cpu" - %false_8639 = torch.constant.bool false - %6807 = torch.aten.empty_strided %6805, %6806, %int4_8636, %int0_8637, %cpu_8638, %false_8639 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,1],si64> - %int31_8640 = torch.constant.int 31 - %6808 = torch.aten.fill.Scalar %6807, %int31_8640 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int4_8641 = torch.constant.int 4 - %int1_8642 = torch.constant.int 1 - %6809 = torch.prim.ListConstruct %int4_8641, %int1_8642 : (!torch.int, !torch.int) -> !torch.list - %6810 = torch.aten.repeat %6804, %6809 : !torch.vtensor<[1],si64>, !torch.list -> !torch.vtensor<[4,1],si64> - %int32_8643 = torch.constant.int 32 - %6811 = torch.aten.mul.Scalar %6800, %int32_8643 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8644 = torch.constant.int 1 - %6812 = torch.aten.add.Tensor %6811, %6808, %int1_8644 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int2_8645 = torch.constant.int 2 - %6813 = torch.aten.mul.Scalar %6812, %int2_8645 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8646 = torch.constant.int 1 - %6814 = torch.aten.add.Tensor %6813, %6810, %int1_8646 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int32_8647 = torch.constant.int 32 - %6815 = torch.aten.mul.Scalar %6814, %int32_8647 : !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %int1_8648 = torch.constant.int 1 - %6816 = torch.aten.add.Tensor %6815, %6802, %int1_8648 : !torch.vtensor<[4,1],si64>, !torch.vtensor<[4,1],si64>, !torch.int -> !torch.vtensor<[4,1],si64> - %6817 = torch.prim.ListConstruct %6816 : (!torch.vtensor<[4,1],si64>) -> !torch.list> - %false_8649 = torch.constant.bool false - %6818 = torch.aten.index_put %6797, %6817, %6749, %false_8649 : !torch.vtensor<[?,8,128],f16>, !torch.list>, !torch.vtensor<[4,1,8,128],f16>, !torch.bool -> !torch.vtensor<[?,8,128],f16> - torch.bind_symbolic_shape %6818, [%357], affine_map<()[s0] -> (s0 * 2048, 8, 128)> : !torch.vtensor<[?,8,128],f16> - %int32_8650 = torch.constant.int 32 - %int2_8651 = torch.constant.int 2 - %int32_8652 = torch.constant.int 32 - %int8_8653 = torch.constant.int 8 - %int128_8654 = torch.constant.int 128 - %6819 = torch.prim.ListConstruct %437, %int32_8650, %int2_8651, %int32_8652, %int8_8653, %int128_8654 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6820 = torch.aten.view %6818, %6819 : !torch.vtensor<[?,8,128],f16>, !torch.list -> 
!torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6820, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int2097152_8655 = torch.constant.int 2097152 - %6821 = torch.prim.ListConstruct %437, %int2097152_8655 : (!torch.int, !torch.int) -> !torch.list - %6822 = torch.aten.view %6820, %6821 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2097152],f16> - torch.overwrite.tensor.contents %6822 overwrites %arg4 : !torch.vtensor<[?,2097152],f16>, !torch.tensor<[?,2097152],f16> - torch.bind_symbolic_shape %6822, [%357], affine_map<()[s0] -> (s0, 2097152)> : !torch.vtensor<[?,2097152],f16> - %int4_8656 = torch.constant.int 4 - %6823 = torch.prim.ListConstruct %int4_8656, %358 : (!torch.int, !torch.int) -> !torch.list - %int1_8657 = torch.constant.int 1 - %6824 = torch.prim.ListConstruct %358, %int1_8657 : (!torch.int, !torch.int) -> !torch.list - %int4_8658 = torch.constant.int 4 - %int0_8659 = torch.constant.int 0 - %cpu_8660 = torch.constant.device "cpu" - %false_8661 = torch.constant.bool false - %6825 = torch.aten.empty_strided %6823, %6824, %int4_8658, %int0_8659, %cpu_8660, %false_8661 : !torch.list, !torch.list, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6825, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int31_8662 = torch.constant.int 31 - %6826 = torch.aten.fill.Scalar %6825, %int31_8662 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6826, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int32_8663 = torch.constant.int 32 - %6827 = torch.aten.mul.Scalar %arg3, %int32_8663 : !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6827, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int1_8664 = torch.constant.int 1 - %6828 = torch.aten.add.Tensor %6827, %6826, %int1_8664 : !torch.vtensor<[4,?],si64>, !torch.vtensor<[4,?],si64>, !torch.int -> !torch.vtensor<[4,?],si64> - torch.bind_symbolic_shape %6828, [%356], affine_map<()[s0] -> (4, s0)> : !torch.vtensor<[4,?],si64> - %int4_8665 = torch.constant.int 4 - %6829 = torch.aten.mul.int %int4_8665, %358 : !torch.int, !torch.int -> !torch.int - %6830 = torch.prim.ListConstruct %6829 : (!torch.int) -> !torch.list - %6831 = torch.aten.view %6828, %6830 : !torch.vtensor<[4,?],si64>, !torch.list -> !torch.vtensor<[?],si64> - torch.bind_symbolic_shape %6831, [%356], affine_map<()[s0] -> (s0 * 4)> : !torch.vtensor<[?],si64> - %int32_8666 = torch.constant.int 32 - %int2_8667 = torch.constant.int 2 - %int32_8668 = torch.constant.int 32 - %int8_8669 = torch.constant.int 8 - %int128_8670 = torch.constant.int 128 - %6832 = torch.prim.ListConstruct %437, %int32_8666, %int2_8667, %int32_8668, %int8_8669, %int128_8670 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6833 = torch.aten.view %6822, %6832 : !torch.vtensor<[?,2097152],f16>, !torch.list -> !torch.vtensor<[?,32,2,32,8,128],f16> - torch.bind_symbolic_shape %6833, [%357], affine_map<()[s0] -> (s0, 32, 2, 32, 8, 128)> : !torch.vtensor<[?,32,2,32,8,128],f16> - %int32_8671 = torch.constant.int 32 - %6834 = torch.aten.mul.int %437, %int32_8671 : !torch.int, !torch.int -> !torch.int - %int2_8672 = torch.constant.int 2 - %int32_8673 = torch.constant.int 32 - %int8_8674 = torch.constant.int 8 - %int128_8675 = torch.constant.int 128 - %6835 = 
torch.prim.ListConstruct %6834, %int2_8672, %int32_8673, %int8_8674, %int128_8675 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6836 = torch.aten.view %6833, %6835 : !torch.vtensor<[?,32,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %6836, [%357], affine_map<()[s0] -> (s0 * 32, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int0_8676 = torch.constant.int 0 - %6837 = torch.aten.index_select %6836, %int0_8676, %6831 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,2,32,8,128],f16> - torch.bind_symbolic_shape %6837, [%356], affine_map<()[s0] -> (s0 * 4, 2, 32, 8, 128)> : !torch.vtensor<[?,2,32,8,128],f16> - %int4_8677 = torch.constant.int 4 - %int2_8678 = torch.constant.int 2 - %int32_8679 = torch.constant.int 32 - %int8_8680 = torch.constant.int 8 - %int128_8681 = torch.constant.int 128 - %6838 = torch.prim.ListConstruct %int4_8677, %358, %int2_8678, %int32_8679, %int8_8680, %int128_8681 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6839 = torch.aten.view %6837, %6838 : !torch.vtensor<[?,2,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6839, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int0_8682 = torch.constant.int 0 - %int0_8683 = torch.constant.int 0 - %int9223372036854775807_8684 = torch.constant.int 9223372036854775807 - %int1_8685 = torch.constant.int 1 - %6840 = torch.aten.slice.Tensor %6839, %int0_8682, %int0_8683, %int9223372036854775807_8684, %int1_8685 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6840, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_8686 = torch.constant.int 1 - %int0_8687 = torch.constant.int 0 - %int9223372036854775807_8688 = torch.constant.int 9223372036854775807 - %int1_8689 = torch.constant.int 1 - %6841 = torch.aten.slice.Tensor %6840, %int1_8686, %int0_8687, %int9223372036854775807_8688, %int1_8689 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6841, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_8690 = torch.constant.int 2 - %int0_8691 = torch.constant.int 0 - %6842 = torch.aten.select.int %6841, %int2_8690, %int0_8691 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6842, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int32_8692 = torch.constant.int 32 - %6843 = torch.aten.mul.int %358, %int32_8692 : !torch.int, !torch.int -> !torch.int - %int2_8693 = torch.constant.int 2 - %int0_8694 = torch.constant.int 0 - %int1_8695 = torch.constant.int 1 - %6844 = torch.aten.slice.Tensor %6842, %int2_8693, %int0_8694, %6843, %int1_8695 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6844, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_8696 = torch.constant.int 0 - %6845 = torch.aten.clone %6844, %int0_8696 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - 
torch.bind_symbolic_shape %6845, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int1_8697 = torch.constant.int 1 - %6846 = torch.aten.size.int %6841, %int1_8697 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_8698 = torch.constant.int 32 - %6847 = torch.aten.mul.int %6846, %int32_8698 : !torch.int, !torch.int -> !torch.int - %int4_8699 = torch.constant.int 4 - %int8_8700 = torch.constant.int 8 - %int128_8701 = torch.constant.int 128 - %6848 = torch.prim.ListConstruct %int4_8699, %6847, %int8_8700, %int128_8701 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6849 = torch.aten._unsafe_view %6845, %6848 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6849, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_8702 = torch.constant.int 0 - %int0_8703 = torch.constant.int 0 - %int9223372036854775807_8704 = torch.constant.int 9223372036854775807 - %int1_8705 = torch.constant.int 1 - %6850 = torch.aten.slice.Tensor %6849, %int0_8702, %int0_8703, %int9223372036854775807_8704, %int1_8705 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6850, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_8706 = torch.constant.int 0 - %int0_8707 = torch.constant.int 0 - %int9223372036854775807_8708 = torch.constant.int 9223372036854775807 - %int1_8709 = torch.constant.int 1 - %6851 = torch.aten.slice.Tensor %6839, %int0_8706, %int0_8707, %int9223372036854775807_8708, %int1_8709 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6851, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int1_8710 = torch.constant.int 1 - %int0_8711 = torch.constant.int 0 - %int9223372036854775807_8712 = torch.constant.int 9223372036854775807 - %int1_8713 = torch.constant.int 1 - %6852 = torch.aten.slice.Tensor %6851, %int1_8710, %int0_8711, %int9223372036854775807_8712, %int1_8713 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,2,32,8,128],f16> - torch.bind_symbolic_shape %6852, [%356], affine_map<()[s0] -> (4, s0, 2, 32, 8, 128)> : !torch.vtensor<[4,?,2,32,8,128],f16> - %int2_8714 = torch.constant.int 2 - %int1_8715 = torch.constant.int 1 - %6853 = torch.aten.select.int %6852, %int2_8714, %int1_8715 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6853, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int2_8716 = torch.constant.int 2 - %int0_8717 = torch.constant.int 0 - %int1_8718 = torch.constant.int 1 - %6854 = torch.aten.slice.Tensor %6853, %int2_8716, %int0_8717, %6843, %int1_8718 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6854, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : !torch.vtensor<[4,?,32,8,128],f16> - %int0_8719 = torch.constant.int 0 - %6855 = torch.aten.clone %6854, %int0_8719 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,32,8,128],f16> - torch.bind_symbolic_shape %6855, [%356], affine_map<()[s0] -> (4, s0, 32, 8, 128)> : 
!torch.vtensor<[4,?,32,8,128],f16> - %int1_8720 = torch.constant.int 1 - %6856 = torch.aten.size.int %6852, %int1_8720 : !torch.vtensor<[4,?,2,32,8,128],f16>, !torch.int -> !torch.int - %int32_8721 = torch.constant.int 32 - %6857 = torch.aten.mul.int %6856, %int32_8721 : !torch.int, !torch.int -> !torch.int - %int4_8722 = torch.constant.int 4 - %int8_8723 = torch.constant.int 8 - %int128_8724 = torch.constant.int 128 - %6858 = torch.prim.ListConstruct %int4_8722, %6857, %int8_8723, %int128_8724 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6859 = torch.aten._unsafe_view %6855, %6858 : !torch.vtensor<[4,?,32,8,128],f16>, !torch.list -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6859, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int0_8725 = torch.constant.int 0 - %int0_8726 = torch.constant.int 0 - %int9223372036854775807_8727 = torch.constant.int 9223372036854775807 - %int1_8728 = torch.constant.int 1 - %6860 = torch.aten.slice.Tensor %6859, %int0_8725, %int0_8726, %int9223372036854775807_8727, %int1_8728 : !torch.vtensor<[4,?,8,128],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,?,8,128],f16> - torch.bind_symbolic_shape %6860, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 128)> : !torch.vtensor<[4,?,8,128],f16> - %int-2_8729 = torch.constant.int -2 - %6861 = torch.aten.unsqueeze %6850, %int-2_8729 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6861, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_8730 = torch.constant.int 1 - %6862 = torch.aten.size.int %6849, %int1_8730 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int - %int4_8731 = torch.constant.int 4 - %int8_8732 = torch.constant.int 8 - %int4_8733 = torch.constant.int 4 - %int128_8734 = torch.constant.int 128 - %6863 = torch.prim.ListConstruct %int4_8731, %6862, %int8_8732, %int4_8733, %int128_8734 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %false_8735 = torch.constant.bool false - %6864 = torch.aten.expand %6861, %6863, %false_8735 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6864, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int0_8736 = torch.constant.int 0 - %6865 = torch.aten.clone %6864, %int0_8736 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16> - torch.bind_symbolic_shape %6865, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16> - %int4_8737 = torch.constant.int 4 - %int32_8738 = torch.constant.int 32 - %int128_8739 = torch.constant.int 128 - %6866 = torch.prim.ListConstruct %int4_8737, %6862, %int32_8738, %int128_8739 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %6867 = torch.aten._unsafe_view %6865, %6866 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list -> !torch.vtensor<[4,?,32,128],f16> - torch.bind_symbolic_shape %6867, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16> - %int-2_8740 = torch.constant.int -2 - %6868 = torch.aten.unsqueeze %6860, %int-2_8740 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,1,128],f16> - torch.bind_symbolic_shape %6868, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 1, 128)> : !torch.vtensor<[4,?,8,1,128],f16> - %int1_8741 = torch.constant.int 1 
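// This region appears to be the single-token decode path for the last two
// transformer layers (layer-index constants 30 and 31 of 32). The
// floor_divide/remainder/gather arithmetic above linearizes a
// (page, layer, k|v, slot) coordinate into the flat paged KV cache, whose
// pages view as [?, 32 layers, 2 (k|v), 32 slots, 8 kv-heads, 128 head-dim]:
//   index = ((page_id * 32 + layer) * 2 + k_or_v) * 32 + slot.
// The unsqueeze -> expand -> clone -> _unsafe_view runs on either side of this
// point replicate each of the 8 KV heads 4x to match the 32 query heads
// (grouped-query attention) ahead of the
// torch.aten._scaled_dot_product_flash_attention_for_cpu call below.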
-    %6869 = torch.aten.size.int %6859, %int1_8741 : !torch.vtensor<[4,?,8,128],f16>, !torch.int -> !torch.int
-    %int4_8742 = torch.constant.int 4
-    %int8_8743 = torch.constant.int 8
-    %int4_8744 = torch.constant.int 4
-    %int128_8745 = torch.constant.int 128
-    %6870 = torch.prim.ListConstruct %int4_8742, %6869, %int8_8743, %int4_8744, %int128_8745 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %false_8746 = torch.constant.bool false
-    %6871 = torch.aten.expand %6868, %6870, %false_8746 : !torch.vtensor<[4,?,8,1,128],f16>, !torch.list<int>, !torch.bool -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %6871, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int0_8747 = torch.constant.int 0
-    %6872 = torch.aten.clone %6871, %int0_8747 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.int -> !torch.vtensor<[4,?,8,4,128],f16>
-    torch.bind_symbolic_shape %6872, [%356], affine_map<()[s0] -> (4, s0 * 32, 8, 4, 128)> : !torch.vtensor<[4,?,8,4,128],f16>
-    %int4_8748 = torch.constant.int 4
-    %int32_8749 = torch.constant.int 32
-    %int128_8750 = torch.constant.int 128
-    %6873 = torch.prim.ListConstruct %int4_8748, %6869, %int32_8749, %int128_8750 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %6874 = torch.aten._unsafe_view %6872, %6873 : !torch.vtensor<[4,?,8,4,128],f16>, !torch.list<int> -> !torch.vtensor<[4,?,32,128],f16>
-    torch.bind_symbolic_shape %6874, [%356], affine_map<()[s0] -> (4, s0 * 32, 32, 128)> : !torch.vtensor<[4,?,32,128],f16>
-    %int1_8751 = torch.constant.int 1
-    %int2_8752 = torch.constant.int 2
-    %6875 = torch.aten.transpose.int %6755, %int1_8751, %int2_8752 : !torch.vtensor<[4,1,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,1,128],f16>
-    %int1_8753 = torch.constant.int 1
-    %int2_8754 = torch.constant.int 2
-    %6876 = torch.aten.transpose.int %6867, %int1_8753, %int2_8754 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
-    torch.bind_symbolic_shape %6876, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
-    %int1_8755 = torch.constant.int 1
-    %int2_8756 = torch.constant.int 2
-    %6877 = torch.aten.transpose.int %6874, %int1_8755, %int2_8756 : !torch.vtensor<[4,?,32,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,32,?,128],f16>
-    torch.bind_symbolic_shape %6877, [%356], affine_map<()[s0] -> (4, 32, s0 * 32, 128)> : !torch.vtensor<[4,32,?,128],f16>
-    %float0.000000e00_8757 = torch.constant.float 0.000000e+00
-    %false_8758 = torch.constant.bool false
-    %none_8759 = torch.constant.none
-    %6878:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%6875, %6876, %6877, %float0.000000e00_8757, %false_8758, %368, %none_8759) : (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.vtensor<[4,32,?,128],f16>, !torch.float, !torch.bool, !torch.vtensor<[4,1,1,?],f16>, !torch.none) -> (!torch.vtensor<[4,32,1,128],f16>, !torch.vtensor<[4,32,1],f32>)
-    %int1_8760 = torch.constant.int 1
-    %int2_8761 = torch.constant.int 2
-    %6879 = torch.aten.transpose.int %6878#0, %int1_8760, %int2_8761 : !torch.vtensor<[4,32,1,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[4,1,32,128],f16>
-    %int4_8762 = torch.constant.int 4
-    %int1_8763 = torch.constant.int 1
-    %int4096_8764 = torch.constant.int 4096
-    %6880 = torch.prim.ListConstruct %int4_8762, %int1_8763, %int4096_8764 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %6881 = torch.aten.view %6879, %6880 : !torch.vtensor<[4,1,32,128],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
-    %int-2_8765 = torch.constant.int -2
-    %int-1_8766 = torch.constant.int -1
-    %6882 = torch.aten.transpose.int %348, %int-2_8765, %int-1_8766 : !torch.vtensor<[4096,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,4096],f16>
-    %int4_8767 = torch.constant.int 4
-    %int4096_8768 = torch.constant.int 4096
-    %6883 = torch.prim.ListConstruct %int4_8767, %int4096_8768 : (!torch.int, !torch.int) -> !torch.list<int>
-    %6884 = torch.aten.view %6881, %6883 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %6885 = torch.aten.mm %6884, %6882 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,4096],f16> -> !torch.vtensor<[4,4096],f16>
-    %int4_8769 = torch.constant.int 4
-    %int1_8770 = torch.constant.int 1
-    %int4096_8771 = torch.constant.int 4096
-    %6886 = torch.prim.ListConstruct %int4_8769, %int1_8770, %int4096_8771 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %6887 = torch.aten.view %6885, %6886 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
-    %int1_8772 = torch.constant.int 1
-    %6888 = torch.aten.add.Tensor %6715, %6887, %int1_8772 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %int6_8773 = torch.constant.int 6
-    %6889 = torch.prims.convert_element_type %6888, %int6_8773 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
-    %int2_8774 = torch.constant.int 2
-    %6890 = torch.aten.pow.Tensor_Scalar %6889, %int2_8774 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
-    %int-1_8775 = torch.constant.int -1
-    %6891 = torch.prim.ListConstruct %int-1_8775 : (!torch.int) -> !torch.list<int>
-    %true_8776 = torch.constant.bool true
-    %none_8777 = torch.constant.none
-    %6892 = torch.aten.mean.dim %6890, %6891, %true_8776, %none_8777 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
-    %float9.999990e-06_8778 = torch.constant.float 9.9999997473787516E-6
-    %int1_8779 = torch.constant.int 1
-    %6893 = torch.aten.add.Scalar %6892, %float9.999990e-06_8778, %int1_8779 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
-    %6894 = torch.aten.rsqrt %6893 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
-    %6895 = torch.aten.mul.Tensor %6889, %6894 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
-    %int5_8780 = torch.constant.int 5
-    %6896 = torch.prims.convert_element_type %6895, %int5_8780 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %6897 = torch.aten.mul.Tensor %349, %6896 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
-    %int5_8781 = torch.constant.int 5
-    %6898 = torch.prims.convert_element_type %6897, %int5_8781 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %int-2_8782 = torch.constant.int -2
-    %int-1_8783 = torch.constant.int -1
-    %6899 = torch.aten.transpose.int %350, %int-2_8782, %int-1_8783 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
-    %int4_8784 = torch.constant.int 4
-    %int4096_8785 = torch.constant.int 4096
-    %6900 = torch.prim.ListConstruct %int4_8784, %int4096_8785 : (!torch.int, !torch.int) -> !torch.list<int>
-    %6901 = torch.aten.view %6898, %6900 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %6902 = torch.aten.mm %6901, %6899 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
-    %int4_8786 = torch.constant.int 4
-    %int1_8787 = torch.constant.int 1
-    %int14336_8788 = torch.constant.int 14336
-    %6903 = torch.prim.ListConstruct %int4_8786, %int1_8787, %int14336_8788 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %6904 = torch.aten.view %6902, %6903 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
-    %6905 = torch.aten.silu %6904 : !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
-    %int-2_8789 = torch.constant.int -2
-    %int-1_8790 = torch.constant.int -1
-    %6906 = torch.aten.transpose.int %351, %int-2_8789, %int-1_8790 : !torch.vtensor<[14336,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,14336],f16>
-    %int4_8791 = torch.constant.int 4
-    %int4096_8792 = torch.constant.int 4096
-    %6907 = torch.prim.ListConstruct %int4_8791, %int4096_8792 : (!torch.int, !torch.int) -> !torch.list<int>
-    %6908 = torch.aten.view %6898, %6907 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %6909 = torch.aten.mm %6908, %6906 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,14336],f16> -> !torch.vtensor<[4,14336],f16>
-    %int4_8793 = torch.constant.int 4
-    %int1_8794 = torch.constant.int 1
-    %int14336_8795 = torch.constant.int 14336
-    %6910 = torch.prim.ListConstruct %int4_8793, %int1_8794, %int14336_8795 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %6911 = torch.aten.view %6909, %6910 : !torch.vtensor<[4,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,1,14336],f16>
-    %6912 = torch.aten.mul.Tensor %6905, %6911 : !torch.vtensor<[4,1,14336],f16>, !torch.vtensor<[4,1,14336],f16> -> !torch.vtensor<[4,1,14336],f16>
-    %int-2_8796 = torch.constant.int -2
-    %int-1_8797 = torch.constant.int -1
-    %6913 = torch.aten.transpose.int %352, %int-2_8796, %int-1_8797 : !torch.vtensor<[4096,14336],f16>, !torch.int, !torch.int -> !torch.vtensor<[14336,4096],f16>
-    %int4_8798 = torch.constant.int 4
-    %int14336_8799 = torch.constant.int 14336
-    %6914 = torch.prim.ListConstruct %int4_8798, %int14336_8799 : (!torch.int, !torch.int) -> !torch.list<int>
-    %6915 = torch.aten.view %6912, %6914 : !torch.vtensor<[4,1,14336],f16>, !torch.list<int> -> !torch.vtensor<[4,14336],f16>
-    %6916 = torch.aten.mm %6915, %6913 : !torch.vtensor<[4,14336],f16>, !torch.vtensor<[14336,4096],f16> -> !torch.vtensor<[4,4096],f16>
-    %int4_8800 = torch.constant.int 4
-    %int1_8801 = torch.constant.int 1
-    %int4096_8802 = torch.constant.int 4096
-    %6917 = torch.prim.ListConstruct %int4_8800, %int1_8801, %int4096_8802 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %6918 = torch.aten.view %6916, %6917 : !torch.vtensor<[4,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,1,4096],f16>
-    %int1_8803 = torch.constant.int 1
-    %6919 = torch.aten.add.Tensor %6888, %6918, %int1_8803 : !torch.vtensor<[4,1,4096],f16>, !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %int6_8804 = torch.constant.int 6
-    %6920 = torch.prims.convert_element_type %6919, %int6_8804 : !torch.vtensor<[4,1,4096],f16>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
-    %int2_8805 = torch.constant.int 2
-    %6921 = torch.aten.pow.Tensor_Scalar %6920, %int2_8805 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f32>
-    %int-1_8806 = torch.constant.int -1
-    %6922 = torch.prim.ListConstruct %int-1_8806 : (!torch.int) -> !torch.list<int>
-    %true_8807 = torch.constant.bool true
-    %none_8808 = torch.constant.none
-    %6923 = torch.aten.mean.dim %6921, %6922, %true_8807, %none_8808 : !torch.vtensor<[4,1,4096],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[4,1,1],f32>
-    %float9.999990e-06_8809 = torch.constant.float 9.9999997473787516E-6
-    %int1_8810 = torch.constant.int 1
-    %6924 = torch.aten.add.Scalar %6923, %float9.999990e-06_8809, %int1_8810 : !torch.vtensor<[4,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[4,1,1],f32>
-    %6925 = torch.aten.rsqrt %6924 : !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,1],f32>
-    %6926 = torch.aten.mul.Tensor %6920, %6925 : !torch.vtensor<[4,1,4096],f32>, !torch.vtensor<[4,1,1],f32> -> !torch.vtensor<[4,1,4096],f32>
-    %int5_8811 = torch.constant.int 5
-    %6927 = torch.prims.convert_element_type %6926, %int5_8811 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %6928 = torch.aten.mul.Tensor %353, %6927 : !torch.vtensor<[4096],f32>, !torch.vtensor<[4,1,4096],f16> -> !torch.vtensor<[4,1,4096],f32>
-    %int5_8812 = torch.constant.int 5
-    %6929 = torch.prims.convert_element_type %6928, %int5_8812 : !torch.vtensor<[4,1,4096],f32>, !torch.int -> !torch.vtensor<[4,1,4096],f16>
-    %int-2_8813 = torch.constant.int -2
-    %int-1_8814 = torch.constant.int -1
-    %6930 = torch.aten.transpose.int %354, %int-2_8813, %int-1_8814 : !torch.vtensor<[128256,4096],f16>, !torch.int, !torch.int -> !torch.vtensor<[4096,128256],f16>
-    %int4_8815 = torch.constant.int 4
-    %int4096_8816 = torch.constant.int 4096
-    %6931 = torch.prim.ListConstruct %int4_8815, %int4096_8816 : (!torch.int, !torch.int) -> !torch.list<int>
-    %6932 = torch.aten.view %6929, %6931 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
-    %6933 = torch.aten.mm %6932, %6930 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,128256],f16> -> !torch.vtensor<[4,128256],f16>
-    %int4_8817 = torch.constant.int 4
-    %int1_8818 = torch.constant.int 1
+    %int4096_7965 = torch.constant.int 4096
+    %7366 = torch.prim.ListConstruct %int4_7964, %int4096_7965 : (!torch.int, !torch.int) -> !torch.list<int>
+    %7367 = torch.aten.view %7363, %7366 : !torch.vtensor<[4,1,4096],f16>, !torch.list<int> -> !torch.vtensor<[4,4096],f16>
+    %7368 = torch.aten.mm %7367, %7365 : !torch.vtensor<[4,4096],f16>, !torch.vtensor<[4096,128256],f16> -> !torch.vtensor<[4,128256],f16>
+    %int4_7966 = torch.constant.int 4
+    %int1_7967 = torch.constant.int 1
     %int128256 = torch.constant.int 128256
-    %6934 = torch.prim.ListConstruct %int4_8817, %int1_8818, %int128256 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-    %6935 = torch.aten.view %6933, %6934 : !torch.vtensor<[4,128256],f16>, !torch.list<int> -> !torch.vtensor<[4,1,128256],f16>
-    return %6935 : !torch.vtensor<[4,1,128256],f16>
-  }
-  util.func private @sharktank_rotary_embedding_4_D_32_128_f32(%arg0: tensor<4x?x32x128xf32>, %arg1: tensor<4x?x128xf32>) -> tensor<4x?x32x128xf32> {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c3 = arith.constant 3 : index
-    %dim = tensor.dim %arg0, %c0 : tensor<4x?x32x128xf32>
-    %dim_0 = tensor.dim %arg0, %c1 : tensor<4x?x32x128xf32>
-    %dim_1 = tensor.dim %arg0, %c2 : tensor<4x?x32x128xf32>
-    %dim_2 = tensor.dim %arg0, %c3 : tensor<4x?x32x128xf32>
-    %0 = tensor.empty(%dim, %dim_0, %dim_1, %dim_2) : tensor<?x?x?x?xf32>
-    %cast = tensor.cast %0 : tensor<?x?x?x?xf32> to tensor<4x?x32x128xf32>
-    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<4x?x128xf32>) outs(%cast : tensor<4x?x32x128xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %2 = linalg.index 0 : index
-      %3 = linalg.index 1 : index
-      %4 = linalg.index 2 : index
-      %5 = linalg.index 3 : index
-      %6 = arith.divui %5, %c2 : index
-      %7 = arith.remui %5, %c2 : index
-      %8 = math.cos %in : f32
-      %9 = math.sin %in : f32
-      %10 = arith.muli %6, %c2 : index
-      %11 = arith.addi %10, %c1 : index
-      %extracted = tensor.extract %arg0[%2, %3, %4, %10] : tensor<4x?x32x128xf32>
-      %extracted_3 = tensor.extract %arg0[%2, %3, %4, %11] : tensor<4x?x32x128xf32>
-      %12 = arith.cmpi eq, %7, %c0 : index
-      %13 = arith.mulf %extracted, %8 : f32
-      %14 = arith.mulf %extracted_3, %9 : f32
-      %15 = arith.subf %13, %14 : f32
-      %16 = arith.mulf %extracted_3, %8 : f32
-      %17 = arith.mulf %extracted, %9 : f32
-      %18 = arith.addf %16, %17 : f32
-      %19 = arith.select %12, %15, %18 : f32
-      linalg.yield %19 : f32
-    } -> tensor<4x?x32x128xf32>
-    util.return %1 : tensor<4x?x32x128xf32>
-  }
-  util.func private @sharktank_rotary_embedding_4_D_8_128_f32(%arg0: tensor<4x?x8x128xf32>, %arg1: tensor<4x?x128xf32>) -> tensor<4x?x8x128xf32> {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c3 = arith.constant 3 : index
-    %dim = tensor.dim %arg0, %c0 : tensor<4x?x8x128xf32>
-    %dim_0 = tensor.dim %arg0, %c1 : tensor<4x?x8x128xf32>
-    %dim_1 = tensor.dim %arg0, %c2 : tensor<4x?x8x128xf32>
-    %dim_2 = tensor.dim %arg0, %c3 : tensor<4x?x8x128xf32>
-    %0 = tensor.empty(%dim, %dim_0, %dim_1, %dim_2) : tensor<?x?x?x?xf32>
-    %cast = tensor.cast %0 : tensor<?x?x?x?xf32> to tensor<4x?x8x128xf32>
-    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<4x?x128xf32>) outs(%cast : tensor<4x?x8x128xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %2 = linalg.index 0 : index
-      %3 = linalg.index 1 : index
-      %4 = linalg.index 2 : index
-      %5 = linalg.index 3 : index
-      %6 = arith.divui %5, %c2 : index
-      %7 = arith.remui %5, %c2 : index
-      %8 = math.cos %in : f32
-      %9 = math.sin %in : f32
-      %10 = arith.muli %6, %c2 : index
-      %11 = arith.addi %10, %c1 : index
-      %extracted = tensor.extract %arg0[%2, %3, %4, %10] : tensor<4x?x8x128xf32>
-      %extracted_3 = tensor.extract %arg0[%2, %3, %4, %11] : tensor<4x?x8x128xf32>
-      %12 = arith.cmpi eq, %7, %c0 : index
-      %13 = arith.mulf %extracted, %8 : f32
-      %14 = arith.mulf %extracted_3, %9 : f32
-      %15 = arith.subf %13, %14 : f32
-      %16 = arith.mulf %extracted_3, %8 : f32
-      %17 = arith.mulf %extracted, %9 : f32
-      %18 = arith.addf %16, %17 : f32
-      %19 = arith.select %12, %15, %18 : f32
-      linalg.yield %19 : f32
-    } -> tensor<4x?x8x128xf32>
-    util.return %1 : tensor<4x?x8x128xf32>
-  }
-  util.func private @sharktank_rotary_embedding_4_1_32_128_f32(%arg0: tensor<4x1x32x128xf32>, %arg1: tensor<4x1x128xf32>) -> tensor<4x1x32x128xf32> {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c3 = arith.constant 3 : index
-    %dim = tensor.dim %arg0, %c0 : tensor<4x1x32x128xf32>
-    %dim_0 = tensor.dim %arg0, %c1 : tensor<4x1x32x128xf32>
-    %dim_1 = tensor.dim %arg0, %c2 : tensor<4x1x32x128xf32>
-    %dim_2 = tensor.dim %arg0, %c3 : tensor<4x1x32x128xf32>
-    %0 = tensor.empty(%dim, %dim_0, %dim_1, %dim_2) : tensor<?x?x?x?xf32>
-    %cast = tensor.cast %0 : tensor<?x?x?x?xf32> to tensor<4x1x32x128xf32>
-    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<4x1x128xf32>) outs(%cast : tensor<4x1x32x128xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %2 = linalg.index 0 : index
-      %3 = linalg.index 1 : index
-      %4 = linalg.index 2 : index
-      %5 = linalg.index 3 : index
-      %6 = arith.divui %5, %c2 : index
-      %7 = arith.remui %5, %c2 : index
-      %8 = math.cos %in : f32
-      %9 = math.sin %in : f32
-      %10 = arith.muli %6, %c2 : index
-      %11 = arith.addi %10, %c1 : index
-      %extracted = tensor.extract %arg0[%2, %3, %4, %10] : tensor<4x1x32x128xf32>
-      %extracted_3 = tensor.extract %arg0[%2, %3, %4, %11] : tensor<4x1x32x128xf32>
-      %12 = arith.cmpi eq, %7, %c0 : index
-      %13 = arith.mulf %extracted, %8 : f32
-      %14 = arith.mulf %extracted_3, %9 : f32
-      %15 = arith.subf %13, %14 : f32
-      %16 = arith.mulf %extracted_3, %8 : f32
-      %17 = arith.mulf %extracted, %9 : f32
-      %18 = arith.addf %16, %17 : f32
-      %19 = arith.select %12, %15, %18 : f32
-      linalg.yield %19 : f32
-    } -> tensor<4x1x32x128xf32>
-    util.return %1 : tensor<4x1x32x128xf32>
+    %7369 = torch.prim.ListConstruct %int4_7966, %int1_7967, %int128256 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %7370 = torch.aten.view %7368, %7369 : !torch.vtensor<[4,128256],f16>, !torch.list<int> -> !torch.vtensor<[4,1,128256],f16>
+    return %7370 : !torch.vtensor<[4,1,128256],f16>
   }
-  util.func private @sharktank_rotary_embedding_4_1_8_128_f32(%arg0: tensor<4x1x8x128xf32>, %arg1: tensor<4x1x128xf32>) -> tensor<4x1x8x128xf32> {
+  util.func private @paged_attention_kv_cache_gather_CACHE_SIZE_T_BLOCK_32_PART_2_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16_BATCH_PAGES_i64__i64__i64_BATCH_PAGES_HEAD_COUNT_KV_8_BLOCK_SEQ_STRIDE_32_ATTN_HEAD_DIM_128_f16(%arg0: tensor<?x32x2x8x32x128xf16>, %arg1: tensor<?x?xi64>, %arg2: tensor<i64>, %arg3: tensor<i64>) -> tensor<?x?x8x32x128xf16> {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
-    %c2 = arith.constant 2 : index
-    %c3 = arith.constant 3 : index
-    %dim = tensor.dim %arg0, %c0 : tensor<4x1x8x128xf32>
-    %dim_0 = tensor.dim %arg0, %c1 : tensor<4x1x8x128xf32>
-    %dim_1 = tensor.dim %arg0, %c2 : tensor<4x1x8x128xf32>
-    %dim_2 = tensor.dim %arg0, %c3 : tensor<4x1x8x128xf32>
-    %0 = tensor.empty(%dim, %dim_0, %dim_1, %dim_2) : tensor<?x?x?x?xf32>
-    %cast = tensor.cast %0 : tensor<?x?x?x?xf32> to tensor<4x1x8x128xf32>
-    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<4x1x128xf32>) outs(%cast : tensor<4x1x8x128xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %2 = linalg.index 0 : index
-      %3 = linalg.index 1 : index
-      %4 = linalg.index 2 : index
-      %5 = linalg.index 3 : index
-      %6 = arith.divui %5, %c2 : index
-      %7 = arith.remui %5, %c2 : index
-      %8 = math.cos %in : f32
-      %9 = math.sin %in : f32
-      %10 = arith.muli %6, %c2 : index
-      %11 = arith.addi %10, %c1 : index
-      %extracted = tensor.extract %arg0[%2, %3, %4, %10] : tensor<4x1x8x128xf32>
-      %extracted_3 = tensor.extract %arg0[%2, %3, %4, %11] : tensor<4x1x8x128xf32>
-      %12 = arith.cmpi eq, %7, %c0 : index
-      %13 = arith.mulf %extracted, %8 : f32
-      %14 = arith.mulf %extracted_3, %9 : f32
-      %15 = arith.subf %13, %14 : f32
-      %16 = arith.mulf %extracted_3, %8 : f32
-      %17 = arith.mulf %extracted, %9 : f32
-      %18 = arith.addf %16, %17 : f32
-      %19 = arith.select %12, %15, %18 : f32
-      linalg.yield %19 : f32
-    } -> tensor<4x1x8x128xf32>
-    util.return %1 : tensor<4x1x8x128xf32>
+    %extracted = tensor.extract %arg2[] : tensor<i64>
+    %extracted_0 = tensor.extract %arg3[] : tensor<i64>
+    %0 = arith.index_cast %extracted : i64 to index
+    %1 = arith.index_cast %extracted_0 : i64 to index
+    %dim = tensor.dim %arg0, %c0 : tensor<?x32x2x8x32x128xf16>
+    %dim_1 = tensor.dim %arg1, %c0 : tensor<?x?xi64>
+    %dim_2 = tensor.dim %arg1, %c1 : tensor<?x?xi64>
+    %extracted_slice = tensor.extract_slice %arg0[0, %0, %1, 0, 0, 0] [%dim, 1, 1, 8, 32, 128] [1, 1, 1, 1, 1, 1] : tensor<?x32x2x8x32x128xf16> to tensor<?x8x32x128xf16>
+    %2 = tensor.empty(%dim_1, %dim_2) : tensor<?x?x8x32x128xf16>
+    %3 = iree_linalg_ext.gather dimension_map = [0] ins(%extracted_slice, %arg1 : tensor<?x8x32x128xf16>, tensor<?x?xi64>) outs(%2 : tensor<?x?x8x32x128xf16>) -> tensor<?x?x8x32x128xf16>
+    util.return %3 : tensor<?x?x8x32x128xf16>
   }
 }